cudf-polars-cu13 25.12.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their public registries and is provided for informational purposes only.
Files changed (47)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +28 -7
  4. cudf_polars/containers/column.py +51 -26
  5. cudf_polars/dsl/expressions/binaryop.py +1 -1
  6. cudf_polars/dsl/expressions/boolean.py +1 -1
  7. cudf_polars/dsl/expressions/selection.py +1 -1
  8. cudf_polars/dsl/expressions/string.py +29 -20
  9. cudf_polars/dsl/expressions/ternary.py +25 -1
  10. cudf_polars/dsl/expressions/unary.py +11 -8
  11. cudf_polars/dsl/ir.py +351 -281
  12. cudf_polars/dsl/translate.py +18 -15
  13. cudf_polars/dsl/utils/aggregations.py +10 -5
  14. cudf_polars/experimental/base.py +10 -0
  15. cudf_polars/experimental/benchmarks/pdsh.py +1 -1
  16. cudf_polars/experimental/benchmarks/utils.py +83 -2
  17. cudf_polars/experimental/distinct.py +2 -0
  18. cudf_polars/experimental/explain.py +1 -1
  19. cudf_polars/experimental/expressions.py +8 -5
  20. cudf_polars/experimental/groupby.py +2 -0
  21. cudf_polars/experimental/io.py +64 -42
  22. cudf_polars/experimental/join.py +15 -2
  23. cudf_polars/experimental/parallel.py +10 -7
  24. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  25. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  26. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  27. cudf_polars/experimental/rapidsmpf/{shuffle.py → collectives/shuffle.py} +90 -114
  28. cudf_polars/experimental/rapidsmpf/core.py +194 -67
  29. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  30. cudf_polars/experimental/rapidsmpf/dispatch.py +6 -3
  31. cudf_polars/experimental/rapidsmpf/io.py +162 -70
  32. cudf_polars/experimental/rapidsmpf/join.py +162 -77
  33. cudf_polars/experimental/rapidsmpf/nodes.py +421 -180
  34. cudf_polars/experimental/rapidsmpf/repartition.py +130 -65
  35. cudf_polars/experimental/rapidsmpf/union.py +24 -5
  36. cudf_polars/experimental/rapidsmpf/utils.py +228 -16
  37. cudf_polars/experimental/shuffle.py +18 -4
  38. cudf_polars/experimental/sort.py +13 -6
  39. cudf_polars/experimental/spilling.py +1 -1
  40. cudf_polars/testing/plugin.py +6 -3
  41. cudf_polars/utils/config.py +67 -0
  42. cudf_polars/utils/versions.py +3 -3
  43. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/METADATA +9 -10
  44. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/RECORD +47 -43
  45. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  46. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  47. {cudf_polars_cu13-25.12.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
- # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES.
  # SPDX-License-Identifier: Apache-2.0

  """Translate polars IR representation to ours."""
@@ -102,7 +102,7 @@ class Translator:
  # IR is versioned with major.minor, minor is bumped for backwards
  # compatible changes (e.g. adding new nodes), major is bumped for
  # incompatible changes (e.g. renaming nodes).
- if (version := self.visitor.version()) >= (10, 1):
+ if (version := self.visitor.version()) >= (11, 1):
  e = NotImplementedError(
  f"No support for polars IR {version=}"
  ) # pragma: no cover; no such version for now.
@@ -379,12 +379,12 @@ def _align_decimal_scales(
  if (
  left_type.id() != target.id() or left_type.scale() != target.scale()
  ): # pragma: no cover; no test yet
- left = expr.Cast(target, left)
+ left = expr.Cast(target, True, left) # noqa: FBT003

  if (
  right_type.id() != target.id() or right_type.scale() != target.scale()
  ): # pragma: no cover; no test yet
- right = expr.Cast(target, right)
+ right = expr.Cast(target, True, right) # noqa: FBT003

  return left, right

@@ -746,7 +746,7 @@ def _(
  *(translator.translate_expr(n=n, schema=schema) for n in node.input),
  )
  if name in needs_cast:
- return expr.Cast(dtype, result_expr)
+ return expr.Cast(dtype, True, result_expr) # noqa: FBT003
  return result_expr
  elif not POLARS_VERSION_LT_131 and isinstance(
  name, plrs._expr_nodes.StructFunction
@@ -787,6 +787,7 @@ def _(
  if not POLARS_VERSION_LT_134
  else expr.Cast(
  DataType(pl.Float64()),
+ True, # noqa: FBT003
  res,
  )
  )
@@ -996,6 +997,9 @@ def _(
  def _(
  node: plrs._expr_nodes.Cast, translator: Translator, dtype: DataType, schema: Schema
  ) -> expr.Expr:
+ # TODO: node.options can be 2 meaning wrap_numerical=True
+ # don't necessarily raise because wrapping isn't always needed, but it's unhandled
+ strict = node.options != 1
  inner = translator.translate_expr(n=node.expr, schema=schema)

  if plc.traits.is_floating_point(inner.dtype.plc_type) and plc.traits.is_fixed_point(
@@ -1003,6 +1007,7 @@ def _(
  ):
  return expr.Cast(
  dtype,
+ strict,
  expr.UnaryFunction(
  inner.dtype, "round", (-dtype.plc_type.scale(), "half_to_even"), inner
  ),
@@ -1011,11 +1016,8 @@ def _(
  # Push casts into literals so we can handle Cast(Literal(Null))
  if isinstance(inner, expr.Literal):
  return inner.astype(dtype)
- elif isinstance(inner, expr.Cast):
- # Translation of Len/Count-agg put in a cast, remove double
- # casts if we have one.
- (inner,) = inner.children
- return expr.Cast(dtype, inner)
+ else:
+ return expr.Cast(dtype, strict, inner)


  @_translate_expr.register
@@ -1037,7 +1039,7 @@ def _(

  if agg_name not in ("count", "n_unique", "mean", "median", "quantile"):
  args = [
- expr.Cast(dtype, arg)
+ expr.Cast(dtype, True, arg) # noqa: FBT003
  if plc.traits.is_fixed_point(arg.dtype.plc_type)
  and arg.dtype.plc_type != dtype.plc_type
  else arg
@@ -1047,7 +1049,7 @@ def _(
  value = expr.Agg(dtype, agg_name, node.options, translator._expr_context, *args)

  if agg_name in ("count", "n_unique") and value.dtype.id() != plc.TypeId.INT32:
- return expr.Cast(value.dtype, value)
+ return expr.Cast(value.dtype, True, value) # noqa: FBT003
  return value


@@ -1088,11 +1090,12 @@ def _(
  f64 = DataType(pl.Float64())
  return expr.Cast(
  dtype,
+ True, # noqa: FBT003
  expr.BinOp(
  f64,
  expr.BinOp._MAPPING[node.op],
- expr.Cast(f64, left),
- expr.Cast(f64, right),
+ expr.Cast(f64, True, left), # noqa: FBT003
+ expr.Cast(f64, True, right), # noqa: FBT003
  ),
  )

@@ -1132,5 +1135,5 @@ def _(
  ) -> expr.Expr:
  value = expr.Len(dtype)
  if dtype.id() != plc.TypeId.INT32:
- return expr.Cast(dtype, value)
+ return expr.Cast(dtype, True, value) # noqa: FBT003
  return value # pragma: no cover; never reached since polars len has uint32 dtype
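Across the translator hunks above, `expr.Cast` now takes an explicit boolean `strict` flag between the target dtype and the child expression; the Cast translation derives it from the Polars node's `options` field (`options == 1` means non-strict; `options == 2`, wrap_numerical, is left as a TODO). A rough before/after sketch; the import paths and the column name are assumptions, not taken from this diff:

    import polars as pl
    from cudf_polars.containers import DataType
    from cudf_polars.dsl import expr

    f64 = DataType(pl.Float64())
    child = expr.Col(f64, "x")

    # 25.12.0: cast strictness was implicit
    # casted = expr.Cast(f64, child)

    # 26.2.0: strict is the second positional argument
    strict = True  # translate.py derives this as `node.options != 1`
    casted = expr.Cast(f64, strict, child)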
@@ -115,7 +115,7 @@ def decompose_single_agg(
  # - min/max/dense/ordinal -> IDX_DTYPE (UInt32/UInt64)
  post_col: expr.Expr = expr.Col(agg.dtype, name)
  if agg.name == "rank":
- post_col = expr.Cast(agg.dtype, post_col)
+ post_col = expr.Cast(agg.dtype, True, post_col) # noqa: FBT003

  return [(named_expr, True)], named_expr.reconstruct(post_col)
  if isinstance(agg, expr.UnaryFunction) and agg.name == "null_count":
@@ -131,10 +131,10 @@ def decompose_single_agg(
  sum_name = next(name_generator)
  sum_agg = expr.NamedExpr(
  sum_name,
- expr.Agg(u32, "sum", (), context, expr.Cast(u32, is_null_bool)),
+ expr.Agg(u32, "sum", (), context, expr.Cast(u32, True, is_null_bool)), # noqa: FBT003
  )
  return [(sum_agg, True)], named_expr.reconstruct(
- expr.Cast(u32, expr.Col(u32, sum_name))
+ expr.Cast(u32, True, expr.Col(u32, sum_name)) # noqa: FBT003
  )
  if isinstance(agg, expr.Col):
  # TODO: collect_list produces null for empty group in libcudf, empty list in polars.
@@ -201,6 +201,7 @@ def decompose_single_agg(
  agg.dtype
  if plc.traits.is_floating_point(agg.dtype.plc_type)
  else DataType(pl.Float64()),
+ True, # noqa: FBT003
  child,
  )
  child_dtype = child.dtype.plc_type
@@ -229,7 +230,11 @@ def decompose_single_agg(

  if agg.name == "sum":
  col = (
- expr.Cast(agg.dtype, expr.Col(DataType(pl.datatypes.Int64()), name))
+ expr.Cast(
+ agg.dtype,
+ True, # noqa: FBT003
+ expr.Col(DataType(pl.datatypes.Int64()), name),
+ )
  if (
  plc.traits.is_integral(agg.dtype.plc_type)
  and agg.dtype.id() != plc.TypeId.INT64
@@ -282,7 +287,7 @@ def decompose_single_agg(
  ) # libcudf promotes to float64
  if agg.dtype.plc_type.id() == plc.TypeId.FLOAT32:
  # Cast back to float32 to match Polars
- post_agg_col = expr.Cast(agg.dtype, post_agg_col)
+ post_agg_col = expr.Cast(agg.dtype, True, post_agg_col) # noqa: FBT003
  return [(named_expr, True)], named_expr.reconstruct(post_agg_col)
  else:
  return [(named_expr, True)], named_expr.reconstruct(
@@ -115,6 +115,7 @@ class DataSourceInfo:
  """

  _unique_stats_columns: set[str]
+ _read_columns: set[str]

  @property
  def row_count(self) -> ColumnStat[int]: # pragma: no cover
@@ -141,6 +142,10 @@ class DataSourceInfo:
  """Add a column needing unique-value information."""
  self._unique_stats_columns.add(column)

+ def add_read_column(self, column: str) -> None:
+ """Add a column needing to be read."""
+ self._read_columns.add(column)
+

  class DataSourcePair(NamedTuple):
  """Pair of table-source and column-name information."""
@@ -240,6 +245,11 @@ class ColumnSourceInfo:
  for table_source, column_name in self.table_source_pairs:
  table_source.add_unique_stats_column(column or column_name)

+ def add_read_column(self, column: str | None = None) -> None:
+ """Add a column needing to be read."""
+ for table_source, column_name in self.table_source_pairs:
+ table_source.add_read_column(column or column_name)
+

  class ColumnStats:
  """
@@ -610,7 +610,7 @@ class PDSHQueries:
  q1 = (
  part.filter(pl.col("p_brand") == var1)
  .filter(pl.col("p_container") == var2)
- .join(lineitem, how="left", left_on="p_partkey", right_on="l_partkey")
+ .join(lineitem, how="inner", left_on="p_partkey", right_on="l_partkey")
  )

  return (
@@ -256,6 +256,8 @@ class RunConfig:
  query_set: str
  collect_traces: bool = False
  stats_planning: bool
+ max_io_threads: int
+ native_parquet: bool

  def __post_init__(self) -> None: # noqa: D105
  if self.gather_shuffle_stats and self.shuffle != "rapidsmpf":
@@ -371,6 +373,8 @@ class RunConfig:
  query_set=args.query_set,
  collect_traces=args.collect_traces,
  stats_planning=args.stats_planning,
+ max_io_threads=args.max_io_threads,
+ native_parquet=args.native_parquet,
  )

  def serialize(self, engine: pl.GPUEngine | None) -> dict:
@@ -400,6 +404,8 @@ class RunConfig:
  print(f"shuffle_method: {self.shuffle}")
  print(f"broadcast_join_limit: {self.broadcast_join_limit}")
  print(f"stats_planning: {self.stats_planning}")
+ if self.runtime == "rapidsmpf":
+ print(f"native_parquet: {self.native_parquet}")
  if self.cluster == "distributed":
  print(f"n_workers: {self.n_workers}")
  print(f"threads: {self.threads}")
@@ -450,10 +456,16 @@ def get_executor_options(
  executor_options["rapidsmpf_spill"] = run_config.rapidsmpf_spill
  if run_config.cluster == "distributed":
  executor_options["cluster"] = "distributed"
- if run_config.stats_planning:
- executor_options["stats_planning"] = {"use_reduction_planning": True}
+ executor_options["stats_planning"] = {
+ "use_reduction_planning": run_config.stats_planning,
+ "use_sampling": (
+ # Always allow row-group sampling for rapidsmpf runtime
+ run_config.stats_planning or run_config.runtime == "rapidsmpf"
+ ),
+ }
  executor_options["client_device_threshold"] = run_config.spill_device
  executor_options["runtime"] = run_config.runtime
+ executor_options["max_io_threads"] = run_config.max_io_threads

  if (
  benchmark
@@ -879,6 +891,18 @@ def parse_args(
  default=False,
  help="Enable statistics planning.",
  )
+ parser.add_argument(
+ "--max-io-threads",
+ default=2,
+ type=int,
+ help="Maximum number of IO threads for rapidsmpf runtime.",
+ )
+ parser.add_argument(
+ "--native-parquet",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Use C++ read_parquet nodes for the rapidsmpf runtime.",
+ )

  parsed_args = parser.parse_args(args)

@@ -908,6 +932,12 @@ def run_polars(

  if run_config.executor != "cpu":
  executor_options = get_executor_options(run_config, benchmark=benchmark)
+ if run_config.runtime == "rapidsmpf":
+ parquet_options = {
+ "use_rapidsmpf_native": run_config.native_parquet,
+ }
+ else:
+ parquet_options = {}
  engine = pl.GPUEngine(
  raise_on_fail=True,
  memory_resource=rmm.mr.CudaAsyncMemoryResource()
@@ -916,6 +946,7 @@ def run_polars(
  cuda_stream_policy=run_config.stream_policy,
  executor=run_config.executor,
  executor_options=executor_options,
+ parquet_options=parquet_options,
  )

  for q_id in run_config.queries:
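Taken together, the benchmark hunks above surface two new knobs on the GPU engine: `max_io_threads` is forwarded through `executor_options`, and `--native-parquet` toggles the rapidsmpf-native parquet reader through `parquet_options`. A minimal sketch of the resulting engine construction; the executor string and option values are illustrative, not taken from this diff:

    import polars as pl

    engine = pl.GPUEngine(
        raise_on_fail=True,
        executor="streaming",
        executor_options={
            "runtime": "rapidsmpf",
            "max_io_threads": 2,  # forwarded from --max-io-threads
            "stats_planning": {
                "use_reduction_planning": False,
                # row-group sampling is always allowed for the rapidsmpf runtime
                "use_sampling": True,
            },
        },
        parquet_options={"use_rapidsmpf_native": True},  # from --native-parquet
    )
    # result = lazy_frame.collect(engine=engine)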
@@ -1163,6 +1194,45 @@ PDSH_TABLE_NAMES: list[str] = [
  ]


+ def print_duckdb_plan(
+ q_id: int,
+ sql: str,
+ dataset_path: Path,
+ suffix: str,
+ query_set: str,
+ args: argparse.Namespace,
+ ) -> None:
+ """Print DuckDB query plan using EXPLAIN."""
+ if duckdb is None:
+ raise ImportError(duckdb_err)
+
+ if query_set == "pdsds":
+ tbl_names = PDSDS_TABLE_NAMES
+ else:
+ tbl_names = PDSH_TABLE_NAMES
+
+ with duckdb.connect() as conn:
+ for name in tbl_names:
+ pattern = (Path(dataset_path) / name).as_posix() + suffix
+ conn.execute(
+ f"CREATE OR REPLACE VIEW {name} AS "
+ f"SELECT * FROM parquet_scan('{pattern}');"
+ )
+
+ if args.explain_logical and args.explain:
+ conn.execute("PRAGMA explain_output = 'all';")
+ elif args.explain_logical:
+ conn.execute("PRAGMA explain_output = 'optimized_only';")
+ else:
+ conn.execute("PRAGMA explain_output = 'physical_only';")
+
+ print(f"\nDuckDB Query {q_id} - Plan\n")
+
+ plan_rows = conn.execute(f"EXPLAIN {sql}").fetchall()
+ for _, line in plan_rows:
+ print(line)
+
+
  def execute_duckdb_query(
  query: str,
  dataset_path: Path,
@@ -1203,6 +1273,17 @@ def run_duckdb(
  raise NotImplementedError(f"Query {q_id} not implemented.") from err

  sql = get_q(run_config)
+
+ if args.explain or args.explain_logical:
+ print_duckdb_plan(
+ q_id=q_id,
+ sql=sql,
+ dataset_path=run_config.dataset_path,
+ suffix=run_config.suffix,
+ query_set=duckdb_queries_cls.name,
+ args=args,
+ )
+
  print(f"DuckDB Executing: {q_id}")
  records[q_id] = []
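The EXPLAIN flow added above reduces to a view over the parquet files plus a `PRAGMA explain_output` setting. A self-contained sketch; the table name, path, and query are made up for illustration:

    import duckdb

    with duckdb.connect() as conn:
        conn.execute(
            "CREATE OR REPLACE VIEW lineitem AS "
            "SELECT * FROM parquet_scan('/data/pdsh/lineitem/*.parquet');"
        )
        # 'all' and 'optimized_only' are the logical-plan variants used above
        conn.execute("PRAGMA explain_output = 'physical_only';")
        for _, line in conn.execute("EXPLAIN SELECT count(*) FROM lineitem").fetchall():
            print(line)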
@@ -97,6 +97,7 @@ def lower_distinct(
  child.schema,
  shuffle_keys,
  config_options.executor.shuffle_method,
+ config_options.executor.shuffler_insertion_method,
  child,
  )
  partition_info[child] = PartitionInfo(
@@ -150,6 +151,7 @@ def lower_distinct(
  new_node.schema,
  shuffle_keys,
  config_options.executor.shuffle_method,
+ config_options.executor.shuffler_insertion_method,
  new_node,
  )
  partition_info[new_node] = PartitionInfo(count=output_count)
@@ -71,7 +71,7 @@ def explain_query(

  lowered_ir, partition_info, _ = rapidsmpf_lower_ir_graph(ir, config)
  else:
- lowered_ir, partition_info = lower_ir_graph(ir, config)
+ lowered_ir, partition_info, _ = lower_ir_graph(ir, config)
  return _repr_ir_tree(lowered_ir, partition_info)
  else:
  if config.executor.name == "streaming":
@@ -41,7 +41,7 @@ from cudf_polars.dsl.expressions.aggregation import Agg
  from cudf_polars.dsl.expressions.base import Col, ExecutionContext, Expr, NamedExpr
  from cudf_polars.dsl.expressions.binaryop import BinOp
  from cudf_polars.dsl.expressions.literal import Literal
- from cudf_polars.dsl.expressions.unary import Cast, UnaryFunction
+ from cudf_polars.dsl.expressions.unary import Cast, Len, UnaryFunction
  from cudf_polars.dsl.ir import IR, Distinct, Empty, HConcat, Select
  from cudf_polars.dsl.traversal import (
  CachingVisitor,
@@ -236,7 +236,7 @@ def _decompose_unique(


  def _decompose_agg_node(
- agg: Agg,
+ agg: Agg | Len,
  input_ir: IR,
  partition_info: MutableMapping[IR, PartitionInfo],
  config_options: ConfigOptions,
@@ -272,7 +272,7 @@ def _decompose_agg_node(
  """
  expr: Expr
  exprs: list[Expr]
- if agg.name == "count":
+ if isinstance(agg, Len) or agg.name == "count":
  # Chunkwise stage
  columns, input_ir, partition_info = select(
  [agg],
@@ -350,6 +350,7 @@ def _decompose_agg_node(
  input_ir.schema,
  shuffle_on,
  config_options.executor.shuffle_method,
+ config_options.executor.shuffler_insertion_method,
  input_ir,
  )
  partition_info[input_ir] = PartitionInfo(
@@ -359,7 +360,7 @@ def _decompose_agg_node(

  # Chunkwise stage
  columns, input_ir, partition_info = select(
- [Cast(agg.dtype, agg)],
+ [Cast(agg.dtype, True, agg)], # noqa: FBT003
  input_ir,
  partition_info,
  names=names,
@@ -453,7 +454,9 @@ def _decompose_expr_node(
  if partition_count == 1 or expr.is_pointwise:
  # Single-partition and pointwise expressions are always supported.
  return expr, input_ir, partition_info
- elif isinstance(expr, Agg) and expr.name in _SUPPORTED_AGGS:
+ elif isinstance(expr, Len) or (
+ isinstance(expr, Agg) and expr.name in _SUPPORTED_AGGS
+ ):
  # This is a supported Agg expression.
  return _decompose_agg_node(
  expr, input_ir, partition_info, config_options, names=names
@@ -249,6 +249,7 @@ def _(
  child.schema,
  ir.keys,
  config_options.executor.shuffle_method,
+ config_options.executor.shuffler_insertion_method,
  child,
  )
  partition_info[child] = PartitionInfo(
@@ -291,6 +292,7 @@ def _(
  gb_pwise.schema,
  grouped_keys,
  config_options.executor.shuffle_method,
+ config_options.executor.shuffler_insertion_method,
  gb_pwise,
  )
  partition_info[gb_inter] = PartitionInfo(count=post_aggregation_count)
@@ -709,6 +709,8 @@ class ParquetSourceInfo(DataSourceInfo):
  # Helper attributes
  self._key_columns: set[str] = set() # Used to fuse lazy row-group sampling
  self._unique_stats: dict[str, UniqueStats] = {}
+ self._read_columns: set[str] = set()
+ self._real_rg_size: dict[str, int] = {}

  @functools.cached_property
  def metadata(self) -> ParquetMetadata:
@@ -731,11 +733,13 @@ class ParquetSourceInfo(DataSourceInfo):
  return

  column_names = self.metadata.column_names
- if not (
- key_columns := [key for key in self._key_columns if key in column_names]
- ): # pragma: no cover; should never get here
- # No key columns found in the file
- raise ValueError(f"None of {self._key_columns} in {column_names}")
+ key_columns = [key for key in self._key_columns if key in column_names]
+ read_columns = list(
+ self._read_columns.intersection(column_names).union(key_columns)
+ )
+ if not read_columns: # pragma: no cover; should never get here
+ # No key columns or read columns found in the file
+ raise ValueError(f"None of {read_columns} in {column_names}")

  sampled_file_count = len(sample_paths)
  num_row_groups_per_file = self.metadata.num_row_groups_per_file
@@ -745,15 +749,15 @@ class ParquetSourceInfo(DataSourceInfo):
  ):
  raise ValueError("Parquet metadata sampling failed.") # pragma: no cover

- n = 0
+ n_sampled = 0
  samples: defaultdict[str, list[int]] = defaultdict(list)
  for path, num_rgs in zip(sample_paths, num_row_groups_per_file, strict=True):
  for rg_id in range(num_rgs):
- n += 1
+ n_sampled += 1
  samples[path].append(rg_id)
- if n == self.max_row_group_samples:
+ if n_sampled == self.max_row_group_samples:
  break
- if n == self.max_row_group_samples:
+ if n_sampled == self.max_row_group_samples:
  break

  exact = sampled_file_count == len(
@@ -763,7 +767,7 @@ class ParquetSourceInfo(DataSourceInfo):
  options = plc.io.parquet.ParquetReaderOptions.builder(
  plc.io.SourceInfo(list(samples))
  ).build()
- options.set_columns(key_columns)
+ options.set_columns(read_columns)
  options.set_row_groups(list(samples.values()))
  stream = get_cuda_stream()
  tbl_w_meta = plc.io.parquet.read_parquet(options, stream=stream)
@@ -773,30 +777,32 @@ class ParquetSourceInfo(DataSourceInfo):
  tbl_w_meta.columns,
  strict=True,
  ):
- row_group_unique_count = plc.stream_compaction.distinct_count(
- column,
- plc.types.NullPolicy.INCLUDE,
- plc.types.NanPolicy.NAN_IS_NULL,
- stream=stream,
- )
- fraction = row_group_unique_count / row_group_num_rows
- # Assume that if every row is unique then this is a
- # primary key otherwise it's a foreign key and we
- # can't use the single row group count estimate.
- # Example, consider a "foreign" key that has 100
- # unique values. If we sample from a single row group,
- # we likely obtain a unique count of 100. But we can't
- # necessarily deduce that that means that the unique
- # count is 100 / num_rows_in_group * num_rows_in_file
- count: int | None = None
- if exact:
- count = row_group_unique_count
- elif row_group_unique_count == row_group_num_rows:
- count = self.row_count.value
- self._unique_stats[name] = UniqueStats(
- ColumnStat[int](value=count, exact=exact),
- ColumnStat[float](value=fraction, exact=exact),
- )
+ self._real_rg_size[name] = column.device_buffer_size() // n_sampled
+ if name in key_columns:
+ row_group_unique_count = plc.stream_compaction.distinct_count(
+ column,
+ plc.types.NullPolicy.INCLUDE,
+ plc.types.NanPolicy.NAN_IS_NULL,
+ stream=stream,
+ )
+ fraction = row_group_unique_count / row_group_num_rows
+ # Assume that if every row is unique then this is a
+ # primary key otherwise it's a foreign key and we
+ # can't use the single row group count estimate.
+ # Example, consider a "foreign" key that has 100
+ # unique values. If we sample from a single row group,
+ # we likely obtain a unique count of 100. But we can't
+ # necessarily deduce that that means that the unique
+ # count is 100 / num_rows_in_group * num_rows_in_file
+ count: int | None = None
+ if exact:
+ count = row_group_unique_count
+ elif row_group_unique_count == row_group_num_rows:
+ count = self.row_count.value
+ self._unique_stats[name] = UniqueStats(
+ ColumnStat[int](value=count, exact=exact),
+ ColumnStat[float](value=fraction, exact=exact),
+ )
  stream.synchronize()

  def _update_unique_stats(self, column: str) -> None:
@@ -822,6 +828,15 @@ class ParquetSourceInfo(DataSourceInfo):
  # the row count, because dictionary encoding can make the
  # in-memory size much larger.
  min_value = max(1, row_count // file_count)
+ if partial_mean_size < min_value and column not in self._real_rg_size:
+ # If the metadata is suspiciously small,
+ # sample "real" data to get a better estimate.
+ self._sample_row_groups()
+ if column in self._real_rg_size:
+ partial_mean_size = int(
+ self._real_rg_size[column]
+ * statistics.mean(self.metadata.num_row_groups_per_file)
+ )
  return ColumnStat[int](max(min_value, partial_mean_size))
  return ColumnStat[int]()
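The fallback added above replaces a suspiciously small metadata-based column size with an estimate built from sampled device buffers: bytes per row group observed during sampling times the mean number of row groups per file. A worked example with made-up numbers:

    # Hypothetical sample: one column occupied 8 MiB of device memory across
    # 4 sampled row groups, and the dataset averages 10 row groups per file.
    real_rg_size = (8 * 1024**2) // 4   # bytes per row group, from device_buffer_size()
    mean_row_groups_per_file = 10
    partial_mean_size = real_rg_size * mean_row_groups_per_file
    print(partial_mean_size)            # 20971520 bytes, i.e. ~20 MiB per file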
@@ -863,14 +878,19 @@ def _extract_scan_stats(
  config_options.parquet_options.max_row_group_samples,
  config_options.executor.stats_planning,
  )
- return {
+ cstats = {
  name: ColumnStats(
  name=name,
  source_info=ColumnSourceInfo(DataSourcePair(table_source_info, name)),
  )
  for name in ir.schema
  }
-
+ # Mark all columns that we are reading in case
+ # we need to sample real data later.
+ if config_options.executor.stats_planning.use_sampling:
+ for name, cs in cstats.items():
+ cs.source_info.add_read_column(name)
+ return cstats
  else:
  return {name: ColumnStats(name=name) for name in ir.schema}

@@ -889,10 +909,10 @@ class DataFrameSourceInfo(DataSourceInfo):

  def __init__(
  self,
- df: Any,
+ df: pl.DataFrame,
  stats_planning: StatsPlanningOptions,
  ):
- self._df = df
+ self._pdf = df
  self._stats_planning = stats_planning
  self._key_columns: set[str] = set()
  self._unique_stats_columns = set()
@@ -901,17 +921,19 @@ class DataFrameSourceInfo(DataSourceInfo):
  @functools.cached_property
  def row_count(self) -> ColumnStat[int]:
  """Data source row-count estimate."""
- return ColumnStat[int](value=self._df.height(), exact=True)
+ return ColumnStat[int](value=self._pdf.height, exact=True)

  def _update_unique_stats(self, column: str) -> None:
  if column not in self._unique_stats and self._stats_planning.use_sampling:
  row_count = self.row_count.value
  try:
  unique_count = (
- self._df.get_column(column).approx_n_unique() if row_count else 0
+ self._pdf._df.get_column(column).approx_n_unique()
+ if row_count
+ else 0
  )
  except pl.exceptions.InvalidOperationError: # pragma: no cover
- unique_count = self._df.get_column(column).n_unique()
+ unique_count = self._pdf._df.get_column(column).n_unique()
  unique_fraction = min((unique_count / row_count), 1.0) if row_count else 1.0
  self._unique_stats[column] = UniqueStats(
  ColumnStat[int](value=unique_count),
@@ -932,7 +954,7 @@ def _extract_dataframescan_stats(
  "Only streaming executor is supported in _extract_dataframescan_stats"
  )
  table_source_info = DataFrameSourceInfo(
- ir.df,
+ pl.DataFrame._from_pydf(ir.df),
  config_options.executor.stats_planning,
  )
  return {
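The DataFrameSourceInfo hunks above wrap the raw PyDataFrame in a `pl.DataFrame` handle and estimate uniqueness with `approx_n_unique`, falling back to the exact `n_unique` when approximation is unsupported. A standalone sketch of the same estimate using the public polars API (illustrative data; the production code goes through the internal `_df` handle):

    import polars as pl

    pdf = pl.DataFrame({"key": [1, 1, 2, 3, 3, 3]})
    row_count = pdf.height
    try:
        unique_count = pdf.get_column("key").approx_n_unique() if row_count else 0
    except pl.exceptions.InvalidOperationError:  # e.g. unsupported dtype
        unique_count = pdf.get_column("key").n_unique()
    unique_fraction = min(unique_count / row_count, 1.0) if row_count else 1.0
    print(unique_count, unique_fraction)  # typically 3 and 0.5 for this data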