cudf-polars-cu13 25.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/experimental/statistics.py
@@ -0,0 +1,795 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Utilities for tracking column statistics."""

from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, TypeVar

from cudf_polars.dsl.expr import Agg, UnaryFunction
from cudf_polars.dsl.ir import (
    IR,
    DataFrameScan,
    Distinct,
    Filter,
    GroupBy,
    HConcat,
    Join,
    Scan,
    Select,
    Sort,
    Union,
)
from cudf_polars.dsl.traversal import post_traversal, traversal
from cudf_polars.experimental.base import (
    ColumnSourceInfo,
    ColumnStat,
    ColumnStats,
    JoinKey,
    StatsCollector,
)
from cudf_polars.experimental.dispatch import (
    initialize_column_stats,
    update_column_stats,
)
from cudf_polars.experimental.expressions import _SUPPORTED_AGGS
from cudf_polars.experimental.utils import _leaf_column_names
from cudf_polars.utils import conversion

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    from cudf_polars.dsl.expr import Expr
    from cudf_polars.experimental.base import JoinInfo
    from cudf_polars.typing import Slice as Zlice
    from cudf_polars.utils.config import ConfigOptions, StatsPlanningOptions


def collect_statistics(root: IR, config_options: ConfigOptions) -> StatsCollector:
    """
    Collect column statistics for a query.

    Parameters
    ----------
    root
        Root IR node for collecting column statistics.
    config_options
        GPUEngine configuration options.

    Returns
    -------
    A StatsCollector object with populated column statistics.
    """
    assert config_options.executor.name == "streaming", (
        "Only streaming executor is supported in collect_statistics"
    )
    stats_planning = config_options.executor.stats_planning
    need_local_statistics = using_local_statistics(stats_planning)
    if need_local_statistics or stats_planning.use_io_partitioning:
        # Start with base statistics.
        # Here we build an outline of the statistics that will be
        # collected before any real data is sampled. We will not
        # read any Parquet metadata or sample any unique-value
        # statistics during this step.
        # (That said, Polars does its own metadata sampling
        # before we ever get the logical plan in cudf-polars)
        stats = collect_base_stats(root, config_options)

        # Avoid collecting local statistics unless we are using them.
        if need_local_statistics:
            # Apply PK-FK heuristics.
            # Here we use PK-FK heuristics to estimate the unique count
            # for each join key. We will not do any unique-value sampling
            # during this step. However, we will use Parquet metadata to
            # estimate the row-count for each table source. This metadata
            # is cached in the DataSourceInfo object for each table.
            if stats_planning.use_join_heuristics:
                apply_pkfk_heuristics(stats.join_info)

            # Update statistics for each node.
            # Here we set local row-count and unique-value statistics
            # on each node in the IR graph. We DO perform unique-value
            # sampling during this step. However, we only sample columns
            # that have been marked as needing unique-value statistics
            # during the `collect_base_stats` step. We always sample ALL
            # "marked" columns within the same table source at once.
            for node in post_traversal([root]):
                update_column_stats(node, stats, config_options)

        return stats

    return StatsCollector()


def collect_base_stats(root: IR, config_options: ConfigOptions) -> StatsCollector:
    """
    Collect base datasource statistics.

    Parameters
    ----------
    root
        Root IR node for collecting base datasource statistics.
    config_options
        GPUEngine configuration options.

    Returns
    -------
    A new StatsCollector object with populated datasource statistics.

    Notes
    -----
    This function initializes the ``StatsCollector`` object
    with the base datasource statistics. The goal is to build an
    outline of the statistics that will be collected before any
    real data is sampled.
    """
    assert config_options.executor.name == "streaming", (
        "Only streaming executor is supported in collect_statistics"
    )
    stats_planning = config_options.executor.stats_planning
    need_local_statistics = using_local_statistics(stats_planning)
    need_join_info = need_local_statistics and stats_planning.use_join_heuristics

    stats: StatsCollector = StatsCollector()
    for node in post_traversal([root]):
        # Initialize column statistics from datasource information
        if need_local_statistics or (
            stats_planning.use_io_partitioning
            and isinstance(node, (Scan, DataFrameScan))
        ):
            stats.column_stats[node] = initialize_column_stats(
                node, stats, config_options
            )
        # Initialize join information
        if need_join_info and isinstance(node, Join):
            initialize_join_info(node, stats)
    return stats


def using_local_statistics(stats_planning: StatsPlanningOptions) -> bool:
    """
    Check if we are using local statistics for query planning.

    Notes
    -----
    This function is used to check if we are using local statistics
    for query-planning purposes. For now, this only returns True
    when `use_reduction_planning=True`. We do not consider `use_io_partitioning`
    here because it only depends on datasource statistics.
    """
    return stats_planning.use_reduction_planning


def initialize_join_info(node: Join, stats: StatsCollector) -> None:
    """
    Initialize join information for the given node.

    Parameters
    ----------
    node
        Join node to initialize join-key information for.
    stats
        StatsCollector object to update.

    Notes
    -----
    This function updates ``stats.join_info``.
    """
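    # Note: key_map links whole (possibly composite) join keys, while
    # column_map links the individual key columns pairwise; PK-FK heuristics
    # are later applied at both granularities in ``apply_pkfk_heuristics``.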
    left, right = node.children
    join_info = stats.join_info
    right_keys = [stats.column_stats[right][n.name] for n in node.right_on]
    left_keys = [stats.column_stats[left][n.name] for n in node.left_on]
    lkey = JoinKey(*left_keys)
    rkey = JoinKey(*right_keys)
    join_info.key_map[lkey].add(rkey)
    join_info.key_map[rkey].add(lkey)
    join_info.join_map[node] = [lkey, rkey]
    for u, v in zip(left_keys, right_keys, strict=True):
        join_info.column_map[u].add(v)
        join_info.column_map[v].add(u)


T = TypeVar("T")


def find_equivalence_sets(join_map: Mapping[T, set[T]]) -> list[set[T]]:
    """
    Find equivalence sets in a join-key mapping.

    Parameters
    ----------
    join_map
        Joined key or column mapping to find equivalence sets in.

    Returns
    -------
    List of equivalence sets.

    Notes
    -----
    This function is used by ``apply_pkfk_heuristics``.
    """
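    # Illustrative example: a mapping like {A: {B}, B: {A, C}, C: {B}, D: set()}
    # yields the equivalence sets [{A, B, C}, {D}].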
    seen = set()
    components = []
    for v in join_map:
        if v not in seen:
            cluster = {v}
            stack = [v]
            while stack:
                node = stack.pop()
                for n in join_map[node]:
                    if n not in cluster:
                        cluster.add(n)
                        stack.append(n)
            components.append(cluster)
            seen.update(cluster)
    return components


def apply_pkfk_heuristics(join_info: JoinInfo) -> None:
    """
    Apply PK-FK unique-count heuristics to join keys.

    Parameters
    ----------
    join_info
        Join information to apply PK-FK heuristics to.

    Notes
    -----
    This function modifies the ``JoinKey`` objects being tracked
    in ``StatsCollector.join_info`` using PK-FK heuristics to
    estimate the "implied" unique-value count. This function also
    modifies the underlying ``ColumnStats`` objects included in
    a join key.
    """
    # This applies the PK-FK matching scheme of
    # https://blobs.duckdb.org/papers/tom-ebergen-msc-thesis-join-order-optimization-with-almost-no-statistics.pdf
    # See section 3.2
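    # Illustrative example: if no key in an equivalence set has an implied
    # unique count yet, the fallback estimate is the smallest known source
    # row count among its keys (e.g. ~150,000 for a key that joins a
    # 150,000-row dimension table to a much larger fact table).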
    for keys in find_equivalence_sets(join_info.key_map):
        implied_unique_count = max(
            (
                c.implied_unique_count
                for c in keys
                if c.implied_unique_count is not None
            ),
            # Default unique-count estimate is the minimum source row count
            default=min(
                (c.source_row_count for c in keys if c.source_row_count is not None),
                default=None,
            ),
        )
        for key in keys:
            # Update unique-count estimate for each join key
            key.implied_unique_count = implied_unique_count

    # We separately apply PK-FK heuristics to individual columns so
    # that we can update ColumnStats.source_info.implied_unique_count
    # and use the per-column information elsewhere in the query plan.
    for cols in find_equivalence_sets(join_info.column_map):
        unique_count = max(
            (
                cs.source_info.implied_unique_count.value
                for cs in cols
                if cs.source_info.implied_unique_count.value is not None
            ),
            default=min(
                (
                    cs.source_info.row_count.value
                    for cs in cols
                    if cs.source_info.row_count.value is not None
                ),
                default=None,
            ),
        )
        for cs in cols:
            cs.source_info.implied_unique_count = ColumnStat[int](unique_count)


def _update_unique_stats_columns(
    child_column_stats: dict[str, ColumnStats],
    key_names: Sequence[str],
) -> None:
    """Update set of unique-stats columns in datasource."""
    for name in key_names:
        if (column_stats := child_column_stats.get(name)) is not None:
            column_stats.source_info.add_unique_stats_column()


@initialize_column_stats.register(IR)
def _default_initialize_column_stats(
    ir: IR, stats: StatsCollector, config_options: ConfigOptions
) -> dict[str, ColumnStats]:
    # Default `initialize_column_stats` implementation.
    if len(ir.children) == 1:
        (child,) = ir.children
        child_column_stats = stats.column_stats.get(child, {})
        return {
            name: child_column_stats.get(name, ColumnStats(name=name)).new_parent()
            for name in ir.schema
        }
    else:  # pragma: no cover
        # Multi-child nodes lose all information by default.
        return {name: ColumnStats(name=name) for name in ir.schema}


@initialize_column_stats.register(Distinct)
def _(
    ir: Distinct, stats: StatsCollector, config_options: ConfigOptions
) -> dict[str, ColumnStats]:
    # Use default initialize_column_stats after updating
    # the known unique-stats columns.
    (child,) = ir.children
    child_column_stats = stats.column_stats.get(child, {})
    key_names = ir.subset or ir.schema
    _update_unique_stats_columns(child_column_stats, list(key_names))
    return _default_initialize_column_stats(ir, stats, config_options)


@initialize_column_stats.register(Join)
def _(
    ir: Join, stats: StatsCollector, config_options: ConfigOptions
) -> dict[str, ColumnStats]:
    # Copy column statistics from both the left and right children.
    # Special cases to consider:
    #  - If a column name appears in both sides of the join,
    #    we take it from the "primary" column (right for "Right"
    #    joins, left for all other joins).
    #  - If a column name doesn't appear in either child, it
    #    corresponds to a non-"primary" column with a suffix.

    children, on = ir.children, (ir.left_on, ir.right_on)
    how = ir.options[0]
    suffix = ir.options[3]
    if how == "Right":
        children, on = children[::-1], on[::-1]
    primary, other = children
    primary_child_stats = stats.column_stats.get(primary, {})
    other_child_stats = stats.column_stats.get(other, {})

    # Build output column statistics
    column_stats: dict[str, ColumnStats] = {}
    for name in ir.schema:
        if name in primary.schema:
            # "Primary" child stats take preference.
            column_stats[name] = primary_child_stats[name].new_parent()
        elif name in other.schema:
            # "Other" column stats apply to everything else.
            column_stats[name] = other_child_stats[name].new_parent()
        else:
            # If the column name was not in either child table,
            # a suffix was added to a column in "other".
            _name = name.removesuffix(suffix)
            column_stats[name] = other_child_stats[_name].new_parent(name=name)

    # Update children
    for p_key, o_key in zip(*on, strict=True):
        column_stats[p_key.name].children = (
            primary_child_stats[p_key.name],
            other_child_stats[o_key.name],
        )
        # Add key columns to set of unique-stats columns.
        primary_child_stats[p_key.name].source_info.add_unique_stats_column()
        other_child_stats[o_key.name].source_info.add_unique_stats_column()

    return column_stats


@initialize_column_stats.register(GroupBy)
def _(
    ir: GroupBy, stats: StatsCollector, config_options: ConfigOptions
) -> dict[str, ColumnStats]:
    (child,) = ir.children
    child_column_stats = stats.column_stats.get(child, {})

    # Update set of source columns we may lazily sample
    _update_unique_stats_columns(child_column_stats, [n.name for n in ir.keys])
    return _default_initialize_column_stats(ir, stats, config_options)


@initialize_column_stats.register(HConcat)
def _(
    ir: HConcat, stats: StatsCollector, config_options: ConfigOptions
) -> dict[str, ColumnStats]:
    child_column_stats = dict(
        itertools.chain.from_iterable(
            stats.column_stats.get(c, {}).items() for c in ir.children
        )
    )
    return {
        name: child_column_stats.get(name, ColumnStats(name=name)).new_parent()
        for name in ir.schema
    }


@initialize_column_stats.register(Union)
def _(
    ir: IR, stats: StatsCollector, config_options: ConfigOptions
) -> dict[str, ColumnStats]:
    return {
        name: ColumnStats(
            name=name,
            children=tuple(stats.column_stats[child][name] for child in ir.children),
            source_info=ColumnSourceInfo(
                *itertools.chain.from_iterable(
                    stats.column_stats[child][name].source_info.table_source_pairs
                    for child in ir.children
                )
            ),
        )
        for name in ir.schema
    }


@initialize_column_stats.register(Scan)
def _(
    ir: Scan, stats: StatsCollector, config_options: ConfigOptions
) -> dict[str, ColumnStats]:
    from cudf_polars.experimental.io import _extract_scan_stats

    return _extract_scan_stats(ir, config_options)


@initialize_column_stats.register(DataFrameScan)
def _(
    ir: DataFrameScan, stats: StatsCollector, config_options: ConfigOptions
) -> dict[str, ColumnStats]:
    from cudf_polars.experimental.io import _extract_dataframescan_stats

    return _extract_dataframescan_stats(ir, config_options)


@initialize_column_stats.register(Select)
def _(
    ir: Select, stats: StatsCollector, config_options: ConfigOptions
) -> dict[str, ColumnStats]:
    (child,) = ir.children
    column_stats: dict[str, ColumnStats] = {}
    unique_stats_columns: list[str] = []
    child_column_stats = stats.column_stats.get(child, {})
    for ne in ir.exprs:
        if leaf_columns := _leaf_column_names(ne.value):
            # New column is based on 1+ child columns.
            # Inherit the source information from the child columns.
            children = tuple(
                child_column_stats.get(col, ColumnStats(name=col))
                for col in leaf_columns
            )
            column_stats[ne.name] = ColumnStats(
                name=ne.name,
                children=children,
                source_info=ColumnSourceInfo(
                    *itertools.chain.from_iterable(
                        cs.source_info.table_source_pairs for cs in children
                    )
                ),
            )
        else:  # pragma: no cover
            # New column is based on 0 child columns.
            # We don't have any source information to inherit.
            # TODO: Do something smart for a Literal source?
            column_stats[ne.name] = ColumnStats(name=ne.name)

        if any(
            isinstance(expr, UnaryFunction) and expr.name == "unique"
            for expr in traversal([ne.value])
        ):
            # Make sure the leaf column is marked as a unique-stats column.
            unique_stats_columns.extend(list(leaf_columns))

    if unique_stats_columns:
        _update_unique_stats_columns(stats.column_stats[child], unique_stats_columns)

    return column_stats


def known_child_row_counts(ir: IR, stats: StatsCollector) -> list[int]:
    """
    Get all non-null row-count estimates for the children of an IR node.

    Parameters
    ----------
    ir
        IR node to get non-null row-count estimates for.
    stats
        StatsCollector object to get row-count estimates from.

    Returns
    -------
    List of non-null row-count estimates for all children.
    """
    return [
        value
        for child in ir.children
        if (value := stats.row_count[child].value) is not None
    ]


def apply_slice(num_rows: int, zlice: Zlice | None) -> int:
    """Apply a slice to a row-count estimate."""
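    # `from_polars_slice` is assumed to return clamped (start, end) bounds,
    # so e.g. a (0, 10) slice of a 1000-row estimate yields 10.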
    if zlice is None:
        return num_rows
    s, e = conversion.from_polars_slice(zlice, num_rows=num_rows)
    return e - s


def apply_predicate_selectivity(
    ir: IR,
    stats: StatsCollector,
    predicate: Expr,
    config_options: ConfigOptions,
) -> None:
    """
    Apply predicate selectivity to row-count and column statistics.

    Parameters
    ----------
    ir
        IR node containing a predicate.
    stats
        The StatsCollector object to update.
    predicate
        The predicate expression.
    config_options
        GPUEngine configuration options.
    """
    assert config_options.executor.name == "streaming", (
        "Only streaming executor is supported in update_column_stats"
    )
    # TODO: Use predicate to generate a better selectivity estimate. Default is 0.8
    selectivity = config_options.executor.stats_planning.default_selectivity
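    # E.g. with a default selectivity of 0.8, a row-count estimate of 1000
    # becomes 800, and known unique-counts are scaled the same way (clamped
    # to the range [1, new row count]).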
    if selectivity < 1.0 and (row_count := stats.row_count[ir].value) is not None:
        row_count = max(1, int(row_count * selectivity))
        stats.row_count[ir] = ColumnStat[int](row_count)
        for column_stats in stats.column_stats[ir].values():
            if (unique_count := column_stats.unique_count.value) is not None:
                column_stats.unique_count = ColumnStat[int](
                    min(max(1, int(unique_count * selectivity)), row_count)
                )


def copy_child_unique_counts(column_stats_mapping: dict[str, ColumnStats]) -> None:
    """
    Copy unique-count estimates from children to parent.

    Parameters
    ----------
    column_stats_mapping
        Mapping of column names to ColumnStats objects.
    """
    for column_stats in column_stats_mapping.values():
        column_stats.unique_count = ColumnStat[int](
            # Assume we get the maximum child unique-count estimate
            max(
                (
                    cs.unique_count.value
                    for cs in column_stats.children
                    if cs.unique_count.value is not None
                ),
                default=None,
            )
        )


@update_column_stats.register(IR)
def _(ir: IR, stats: StatsCollector, config_options: ConfigOptions) -> None:
    # Default `update_column_stats` implementation.
    # Propagate largest child row-count estimate.
    stats.row_count[ir] = ColumnStat[int](
        max(known_child_row_counts(ir, stats), default=None)
    )

    # Apply slice if relevant.
    # We can also limit the unique-count estimate to the row-count estimate.
    max_unique_count: int | None = None
    if (value := stats.row_count[ir].value) is not None and isinstance(ir, Sort):
        # Apply slice for IR nodes supporting slice pushdown.
        # TODO: Include types other than Sort.
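        # E.g. a Sort with zlice (0, 100) over a 10,000-row child caps both
        # the row-count estimate and each unique-count estimate at 100.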
        max_unique_count = apply_slice(value, ir.zlice)
        stats.row_count[ir] = ColumnStat[int](max_unique_count)

    for column_stats in stats.column_stats[ir].values():
        column_stats.unique_count = ColumnStat[int](
            max(
                (
                    min(value, max_unique_count or value)
                    for cs in column_stats.children
                    if (value := cs.unique_count.value) is not None
                ),
                default=None,
            )
        )

    if isinstance(ir, Filter):
        apply_predicate_selectivity(ir, stats, ir.mask.value, config_options)


@update_column_stats.register(DataFrameScan)
def _(ir: DataFrameScan, stats: StatsCollector, config_options: ConfigOptions) -> None:
    # Use datasource row-count estimate.
    if stats.column_stats[ir]:
        stats.row_count[ir] = next(
            iter(stats.column_stats[ir].values())
        ).source_info.row_count
    else:  # pragma: no cover; We always have stats.column_stats[ir]
        stats.row_count[ir] = ColumnStat[int](None)

    # Update unique-count estimates with sampled statistics
    for column_stats in stats.column_stats[ir].values():
        if column_stats.source_info.implied_unique_count.value is None:
            # We don't have a unique-count estimate, so we need to sample the data.
            source_unique_stats = column_stats.source_info.unique_stats(force=False)
            if source_unique_stats.count.value is not None:
                column_stats.unique_count = source_unique_stats.count
        else:
            column_stats.unique_count = column_stats.source_info.implied_unique_count


@update_column_stats.register(Scan)
def _(ir: Scan, stats: StatsCollector, config_options: ConfigOptions) -> None:
    # Use datasource row-count estimate.
    if stats.column_stats[ir]:
        stats.row_count[ir] = next(
            iter(stats.column_stats[ir].values())
        ).source_info.row_count
    else:  # pragma: no cover; We always have stats.column_stats[ir]
        # No column stats available.
        stats.row_count[ir] = ColumnStat[int](None)

    # Account for the n_rows argument.
    if ir.n_rows != -1:
        if (metadata_value := stats.row_count[ir].value) is not None:
            stats.row_count[ir] = ColumnStat[int](min(metadata_value, ir.n_rows))
        else:
            stats.row_count[ir] = ColumnStat[int](ir.n_rows)

    # Update unique-count estimates with estimated and/or sampled statistics
    for column_stats in stats.column_stats[ir].values():
        if column_stats.source_info.implied_unique_count.value is None:
            # We don't have a unique-count estimate, so we need to sample the data.
            source_unique_stats = column_stats.source_info.unique_stats(force=False)
            if source_unique_stats.count.value is not None:
                column_stats.unique_count = source_unique_stats.count
            elif (
                unique_fraction := source_unique_stats.fraction.value
            ) is not None and (row_count := stats.row_count[ir].value) is not None:
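                # E.g. a sampled unique fraction of 0.25 with an estimated
                # 1,000,000 rows gives a unique-count estimate of 250,000.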
                column_stats.unique_count = ColumnStat[int](
                    max(1, int(unique_fraction * row_count))
                )
        else:
            column_stats.unique_count = column_stats.source_info.implied_unique_count

    if ir.predicate is not None and ir.n_rows == -1:
        apply_predicate_selectivity(ir, stats, ir.predicate.value, config_options)


@update_column_stats.register(Select)
def _(ir: Select, stats: StatsCollector, config_options: ConfigOptions) -> None:
    # Update statistics for a Select node.

    # Start by copying the child unique-count estimates.
    copy_child_unique_counts(stats.column_stats[ir])

    # Now update the row-count estimate.
    (child,) = ir.children
    child_row_count = stats.row_count.get(child, ColumnStat[int](None)).value
    row_count_estimates: list[int | None] = []
    for ne in ir.exprs:
        child_column_stats = stats.column_stats[ir][ne.name].children
        if isinstance(ne.value, Agg) and ne.value.name in _SUPPORTED_AGGS:
            # This aggregation outputs a single row.
            row_count_estimates.append(1)
            stats.column_stats[ir][ne.name].unique_count = ColumnStat[int](
                value=1, exact=True
            )
        elif (
            len(child_column_stats) == 1
            and any(
                isinstance(expr, UnaryFunction) and expr.name == "unique"
                for expr in traversal([ne.value])
            )
            and (value := child_column_stats[0].unique_count.value) is not None
        ):
            # We are doing a Select(unique) operation.
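            # E.g. ``df.select(pl.col("a").unique())`` produces one row per
            # distinct value of "a", so the child's unique-count estimate
            # becomes the row-count estimate.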
            row_count_estimates.append(value)
        else:
            # Fallback case - use the child row-count estimate.
            row_count_estimates.append(child_row_count)

    stats.row_count[ir] = ColumnStat[int](
        max((rc for rc in row_count_estimates if rc is not None), default=None),
    )


@update_column_stats.register(Distinct)
@update_column_stats.register(GroupBy)
def _(
    ir: Distinct | GroupBy, stats: StatsCollector, config_options: ConfigOptions
) -> None:
    # Update statistics for a Distinct or GroupBy node.
    (child,) = ir.children
    child_column_stats = stats.column_stats[child]
    child_row_count = stats.row_count[child].value
    key_names = (
        list(ir.subset or ir.schema)
        if isinstance(ir, Distinct)
        else [n.name for n in ir.keys]
    )
    unique_counts = [
        # k will be missing from child_column_stats if it's a literal
        child_column_stats.get(k, ColumnStats(name=k)).unique_count.value
        for k in key_names
    ]
    known_unique_count = sum(c for c in unique_counts if c is not None)
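    # E.g. keys with known unique counts of 10 and 20 give an estimated 30
    # groups, capped at the child row count below.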
    unknown_unique_count = sum(c is None for c in unique_counts)
    if unknown_unique_count > 0:
        # Use the child row-count to be conservative.
        # TODO: Should we use a different heuristic here? For example,
        # we could assume each unknown key introduces a factor of 3.
        stats.row_count[ir] = ColumnStat[int](child_row_count)
    else:
        unique_count = known_unique_count
        if child_row_count is not None:
            # Don't allow the unique-count to exceed the child row-count.
            unique_count = min(child_row_count, unique_count)
        stats.row_count[ir] = ColumnStat[int](unique_count)

    copy_child_unique_counts(stats.column_stats[ir])


@update_column_stats.register(Join)
def _(ir: Join, stats: StatsCollector, config_options: ConfigOptions) -> None:
    # Apply basic join-cardinality estimation.
    child_row_counts = known_child_row_counts(ir, stats)
    if len(child_row_counts) == 2:
        # Both children have row-count estimates.

        # Use the PK-FK unique-count estimate for the join key.
        # Otherwise, use the maximum unique-count estimate from the children.
        unique_count_estimate = max(
            # Join-based estimate (higher priority).
            [
                u.implied_unique_count
                for u in stats.join_info.join_map.get(ir, [])
                if u.implied_unique_count is not None
            ],
            default=None,
        )
        # TODO: Use local unique-count statistics if the implied unique-count
        # estimates are missing. This never happens for now, but it will happen
        # if/when we add a config option to disable PK-FK heuristics.

        # Calculate the output row-count estimate.
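        # E.g. 1,000,000 left rows joined with 150,000 right rows on a key
        # with an implied unique count of 150,000 gives an estimate of
        # (1,000,000 * 150,000) // 150,000 = 1,000,000 output rows.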
        left_rows, right_rows = child_row_counts
        if unique_count_estimate is not None:
            stats.row_count[ir] = ColumnStat[int](
                max(1, (left_rows * right_rows) // unique_count_estimate)
            )
        else:  # pragma: no cover; We always have a unique-count estimate (for now).
            stats.row_count[ir] = ColumnStat[int](max((1, left_rows, right_rows)))
    else:
        # One or more children have an unknown row-count estimate.
        stats.row_count[ir] = ColumnStat[int](None)

    copy_child_unique_counts(stats.column_stats[ir])


@update_column_stats.register(Union)
def _(ir: Union, stats: StatsCollector, config_options: ConfigOptions) -> None:
    # Add up child row-count estimates.
    row_counts = known_child_row_counts(ir, stats)
    stats.row_count[ir] = ColumnStat[int](sum(row_counts) or None)
    # Add up unique counts (NOTE: This is probably very conservative).
    for column_stats in stats.column_stats.get(ir, {}).values():
        column_stats.unique_count = ColumnStat[int](
            sum(
                (
                    cs.unique_count.value
                    for cs in column_stats.children
                    if cs.unique_count.value is not None
                ),
            )
            or None
        )