cudf-polars-cu13 25.10.0-py3-none-any.whl → 26.2.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/utils/aggregations.py

@@ -6,6 +6,7 @@
  from __future__ import annotations

  import itertools
+ from decimal import Decimal
  from functools import partial
  from typing import TYPE_CHECKING, Any

@@ -16,7 +17,7 @@ import pylibcudf as plc
  from cudf_polars.containers import DataType
  from cudf_polars.dsl import expr, ir
  from cudf_polars.dsl.expressions.base import ExecutionContext
- from cudf_polars.utils.versions import POLARS_VERSION_LT_1323
+ from cudf_polars.utils.versions import POLARS_VERSION_LT_134, POLARS_VERSION_LT_1323

  if TYPE_CHECKING:
  from collections.abc import Callable, Generator, Iterable, Sequence
@@ -45,6 +46,11 @@ def replace_nulls(col: expr.Expr, value: Any, *, is_top: bool) -> expr.Expr:
  """
  if not is_top:
  return col
+ if isinstance(value, int) and value == 0:
+ dtype = col.dtype.plc_type
+ value = (
+ Decimal(0).scaleb(dtype.scale()) if plc.traits.is_fixed_point(dtype) else 0
+ )
  return expr.UnaryFunction(
  col.dtype, "fill_null", (), col, expr.Literal(col.dtype, value)
  )
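The change to `replace_nulls` above rewrites an integer `0` fill value into a decimal zero when the target column is a fixed-point type, so the literal's scale matches the column's. A pure-Python sketch of why `Decimal(0).scaleb(dtype.scale())` produces the right literal (the scale below is made up; pylibcudf stores fixed-point scales as negative exponents):

```python
# Illustration of the Decimal(0).scaleb(...) pattern used in replace_nulls above.
from decimal import Decimal

scale = -2                       # hypothetical scale of a fixed-point column
zero = Decimal(0).scaleb(scale)  # Decimal('0E-2'), i.e. 0.00
assert zero == 0
assert zero.as_tuple().exponent == scale
```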
@@ -91,17 +97,25 @@ def decompose_single_agg(
  name = named_expr.name
  if isinstance(agg, expr.UnaryFunction) and agg.name in {
  "rank",
+ "fill_null_with_strategy",
+ "cum_sum",
  }:
  if context != ExecutionContext.WINDOW:
  raise NotImplementedError(
  f"{agg.name} is not supported in groupby or rolling context"
  )
+ if agg.name == "fill_null_with_strategy" and (
+ strategy := agg.options[0]
+ ) not in {"forward", "backward"}:
+ raise NotImplementedError(
+ f"fill_null({strategy=}) not supported in a groupy or rolling context"
+ )
  # Ensure Polars semantics for dtype:
  # - average -> Float64
  # - min/max/dense/ordinal -> IDX_DTYPE (UInt32/UInt64)
  post_col: expr.Expr = expr.Col(agg.dtype, name)
  if agg.name == "rank":
- post_col = expr.Cast(agg.dtype, post_col)
+ post_col = expr.Cast(agg.dtype, True, post_col) # noqa: FBT003

  return [(named_expr, True)], named_expr.reconstruct(post_col)
  if isinstance(agg, expr.UnaryFunction) and agg.name == "null_count":
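With the branch above, `rank`, `cum_sum`, and `fill_null_with_strategy` are now accepted in a window context, with fill strategies limited to `"forward"` and `"backward"`. A small Polars snippet (made-up data) of the kind of expression this path appears to cover:

```python
# Hypothetical example: a forward fill evaluated inside an .over() window.
import polars as pl

lf = pl.LazyFrame({"g": [1, 1, 2, 2], "x": [1.0, None, None, 4.0]})
q = lf.select(pl.col("x").fill_null(strategy="forward").over("g"))
# Collecting with engine="gpu" would exercise the window path above; per the
# check, a strategy such as "mean" is still rejected with NotImplementedError.
print(q.collect())
```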
@@ -117,10 +131,10 @@ def decompose_single_agg(
  sum_name = next(name_generator)
  sum_agg = expr.NamedExpr(
  sum_name,
- expr.Agg(u32, "sum", (), expr.Cast(u32, is_null_bool)),
+ expr.Agg(u32, "sum", (), context, expr.Cast(u32, True, is_null_bool)), # noqa: FBT003
  )
  return [(sum_agg, True)], named_expr.reconstruct(
- expr.Cast(u32, expr.Col(u32, sum_name))
+ expr.Cast(u32, True, expr.Col(u32, sum_name)) # noqa: FBT003
  )
  if isinstance(agg, expr.Col):
  # TODO: collect_list produces null for empty group in libcudf, empty list in polars.
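The `null_count` decomposition above computes the count by casting the is-null mask to UInt32 and summing it, then casting the summed column back to UInt32. The same idea expressed directly in Polars, on made-up data:

```python
# Counting nulls as the sum of an is-null mask, the trick used above.
import polars as pl

df = pl.DataFrame({"x": [1, None, 3, None]})
out = df.select(pl.col("x").is_null().cast(pl.UInt32).sum())
print(out)  # a single row containing 2
```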
@@ -146,15 +160,6 @@
  return [(named_expr, True)], named_expr.reconstruct(expr.Col(agg.dtype, name))
  if isinstance(agg, (expr.Literal, expr.LiteralColumn)):
  return [], named_expr
- if (
- is_top
- and isinstance(agg, expr.UnaryFunction)
- and agg.name == "fill_null_with_strategy"
- ):
- strategy, _ = agg.options
- raise NotImplementedError(
- f"fill_null_with_strategy({strategy!r}) is not supported in groupby aggregations"
- )
  if isinstance(agg, expr.Agg):
  if agg.name == "quantile":
  # Second child the requested quantile (which is asserted
@@ -163,7 +168,7 @@
  else:
  (child,) = agg.children
  needs_masking = agg.name in {"min", "max"} and plc.traits.is_floating_point(
- child.dtype.plc
+ child.dtype.plc_type
  )
  if needs_masking and agg.options:
  # pl.col("a").nan_max or nan_min
@@ -177,7 +182,7 @@
  if any(has_agg for _, has_agg in aggs):
  raise NotImplementedError("Nested aggs in groupby not supported")

- child_dtype = child.dtype.plc
+ child_dtype = child.dtype.plc_type
  req = agg.agg_request
  is_median = agg.name == "median"
  is_quantile = agg.name == "quantile"
@@ -186,18 +191,22 @@
  # mean/median on decimal: Polars returns float -> pre-cast
  decimal_unsupported = False
  if plc.traits.is_fixed_point(child_dtype):
- if is_quantile:
+ cast_for_quantile = is_quantile and not POLARS_VERSION_LT_134
+ cast_for_mean_or_median = (
+ agg.name in {"mean", "median"}
+ ) and plc.traits.is_floating_point(agg.dtype.plc_type)
+
+ if cast_for_quantile or cast_for_mean_or_median:
+ child = expr.Cast(
+ agg.dtype
+ if plc.traits.is_floating_point(agg.dtype.plc_type)
+ else DataType(pl.Float64()),
+ True, # noqa: FBT003
+ child,
+ )
+ child_dtype = child.dtype.plc_type
+ elif is_quantile and POLARS_VERSION_LT_134: # pragma: no cover
  decimal_unsupported = True
- elif agg.name in {"mean", "median"}:
- tid = agg.dtype.plc.id()
- if tid in {plc.TypeId.FLOAT32, plc.TypeId.FLOAT64}:
- cast_to = (
- DataType(pl.Float64)
- if tid == plc.TypeId.FLOAT64
- else DataType(pl.Float32)
- )
- child = expr.Cast(cast_to, child)
- child_dtype = child.dtype.plc

  is_group_quantile_supported = plc.traits.is_integral(
  child_dtype
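The rewritten branch above pre-casts a decimal child to the aggregation's floating-point dtype (or Float64) before running mean/median/quantile, because Polars reports these aggregations on Decimal columns as floats. A small check of that Polars behaviour (made-up values; assumes a recent Polars release):

```python
# Mean of a Decimal column comes back as a float dtype in Polars, which is
# why the plan casts the decimal child before aggregating.
from decimal import Decimal

import polars as pl

s = pl.Series("d", [Decimal("1.10"), Decimal("2.30")], dtype=pl.Decimal(10, 2))
out = pl.DataFrame({"d": s}).select(pl.col("d").mean())
print(out.schema)  # expected: {'d': Float64}
```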
@@ -221,9 +230,13 @@

  if agg.name == "sum":
  col = (
- expr.Cast(agg.dtype, expr.Col(DataType(pl.datatypes.Int64()), name))
+ expr.Cast(
+ agg.dtype,
+ True, # noqa: FBT003
+ expr.Col(DataType(pl.datatypes.Int64()), name),
+ )
  if (
- plc.traits.is_integral(agg.dtype.plc)
+ plc.traits.is_integral(agg.dtype.plc_type)
  and agg.dtype.id() != plc.TypeId.INT64
  )
  else expr.Col(agg.dtype, name)
@@ -272,9 +285,9 @@
  post_agg_col: expr.Expr = expr.Col(
  DataType(pl.Float64()), name
  ) # libcudf promotes to float64
- if agg.dtype.plc.id() == plc.TypeId.FLOAT32:
+ if agg.dtype.plc_type.id() == plc.TypeId.FLOAT32:
  # Cast back to float32 to match Polars
- post_agg_col = expr.Cast(agg.dtype, post_agg_col)
+ post_agg_col = expr.Cast(agg.dtype, True, post_agg_col) # noqa: FBT003
  return [(named_expr, True)], named_expr.reconstruct(post_agg_col)
  else:
  return [(named_expr, True)], named_expr.reconstruct(
cudf_polars/dsl/utils/reshape.py

@@ -4,12 +4,19 @@

  from __future__ import annotations

+ from typing import TYPE_CHECKING
+
  import pylibcudf as plc

  from cudf_polars.containers import Column

+ if TYPE_CHECKING:
+ from rmm.pylibrmm.stream import Stream
+

- def broadcast(*columns: Column, target_length: int | None = None) -> list[Column]:
+ def broadcast(
+ *columns: Column, target_length: int | None = None, stream: Stream
+ ) -> list[Column]:
  """
  Broadcast a sequence of columns to a common length.

@@ -20,6 +27,9 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
  target_length
  Optional length to broadcast to. If not provided, uses the
  non-unit length of existing columns.
+ stream
+ CUDA stream used for device memory operations and kernel launches
+ on this dataframe.

  Returns
  -------
@@ -63,7 +73,9 @@ def broadcast(*columns: Column, target_length: int | None = None) -> list[Column
  column
  if column.size != 1
  else Column(
- plc.Column.from_scalar(column.obj_scalar, nrows),
+ plc.Column.from_scalar(
+ column.obj_scalar(stream=stream), nrows, stream=stream
+ ),
  is_sorted=plc.types.Sorted.YES,
  order=plc.types.Order.ASCENDING,
  null_order=plc.types.NullOrder.BEFORE,
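`broadcast` now threads an explicit CUDA stream through `obj_scalar` and `plc.Column.from_scalar` when replicating unit-length columns. A rough sketch of that replication step in isolation (requires a GPU; `DEFAULT_STREAM` from `rmm.pylibrmm.stream` is an assumption on my part, while the stream-aware pylibcudf calls are the ones shown in the diff):

```python
# Replicating a scalar to a fixed number of rows on an explicit stream,
# mirroring the from_scalar call used by broadcast above.
import pylibcudf as plc
from rmm.pylibrmm.stream import DEFAULT_STREAM  # assumed import path

scalar = plc.Scalar.from_py(42, plc.DataType(plc.TypeId.INT64), stream=DEFAULT_STREAM)
col = plc.Column.from_scalar(scalar, 5, stream=DEFAULT_STREAM)  # five rows of 42
```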
cudf_polars/dsl/utils/rolling.py

@@ -13,7 +13,7 @@ from cudf_polars.dsl import expr, ir
  from cudf_polars.dsl.expressions.base import ExecutionContext
  from cudf_polars.dsl.utils.aggregations import apply_pre_evaluation
  from cudf_polars.dsl.utils.naming import unique_names
- from cudf_polars.dsl.utils.windows import offsets_to_windows
+ from cudf_polars.dsl.utils.windows import duration_to_int

  if TYPE_CHECKING:
  from collections.abc import Sequence
@@ -74,10 +74,13 @@ def rewrite_rolling(
  index_name = options.rolling.index_column
  index_dtype = schema[index_name]
  index_col = expr.Col(index_dtype, index_name)
- if plc.traits.is_integral(index_dtype.plc) and index_dtype.id() != plc.TypeId.INT64:
+ if (
+ plc.traits.is_integral(index_dtype.plc_type)
+ and index_dtype.id() != plc.TypeId.INT64
+ ):
  plc_index_dtype = plc.DataType(plc.TypeId.INT64)
  else:
- plc_index_dtype = index_dtype.plc
+ plc_index_dtype = index_dtype.plc_type
  index = expr.NamedExpr(index_name, index_col)
  temp_prefix = "_" * max(map(len, schema))
  if len(aggs) > 0:
@@ -92,9 +95,9 @@
  else:
  rolling_schema = schema
  apply_post_evaluation = lambda inp: inp # noqa: E731
- preceding, following = offsets_to_windows(
- plc_index_dtype, options.rolling.offset, options.rolling.period
- )
+ preceding_ordinal = duration_to_int(plc_index_dtype, *options.rolling.offset)
+ following_ordinal = duration_to_int(plc_index_dtype, *options.rolling.period)
+
  if (n := len(keys)) > 0:
  # Grouped rolling in polars sorts the output by the groups.
  inp = ir.Sort(
@@ -110,8 +113,9 @@
  ir.Rolling(
  rolling_schema,
  index,
- preceding,
- following,
+ plc_index_dtype,
+ preceding_ordinal,
+ following_ordinal,
  options.rolling.closed_window,
  keys,
  aggs,
cudf_polars/dsl/utils/windows.py

@@ -12,7 +12,9 @@ import polars as pl
  import pylibcudf as plc

  if TYPE_CHECKING:
- from cudf_polars.typing import ClosedInterval, Duration
+ from rmm.pylibrmm.stream import Stream
+
+ from cudf_polars.typing import ClosedInterval


  __all__ = [
@@ -75,7 +77,7 @@ def duration_to_int(
  return -value if negative else value


- def duration_to_scalar(dtype: plc.DataType, value: int) -> plc.Scalar:
+ def duration_to_scalar(dtype: plc.DataType, value: int, stream: Stream) -> plc.Scalar:
  """
  Convert a raw polars duration value to a pylibcudf scalar.

@@ -86,6 +88,9 @@ def duration_to_scalar(dtype: plc.DataType, value: int) -> plc.Scalar:
  value
  The raw value as in integer. If `dtype` represents a timestamp
  type, this should be in nanoseconds.
+ stream
+ CUDA stream used for device memory operations and kernel launches
+ on this dataframe. The returned scalar will be valid on this stream.

  Returns
  -------
@@ -99,20 +104,28 @@ def duration_to_scalar(dtype: plc.DataType, value: int) -> plc.Scalar:
  """
  tid = dtype.id()
  if tid == plc.TypeId.INT64:
- return plc.Scalar.from_py(value, dtype)
+ return plc.Scalar.from_py(value, dtype, stream=stream)
  elif tid == plc.TypeId.TIMESTAMP_NANOSECONDS:
- return plc.Scalar.from_py(value, plc.DataType(plc.TypeId.DURATION_NANOSECONDS))
+ return plc.Scalar.from_py(
+ value, plc.DataType(plc.TypeId.DURATION_NANOSECONDS), stream=stream
+ )
  elif tid == plc.TypeId.TIMESTAMP_MICROSECONDS:
  return plc.Scalar.from_py(
- value // 1000, plc.DataType(plc.TypeId.DURATION_MICROSECONDS)
+ value // 1000,
+ plc.DataType(plc.TypeId.DURATION_MICROSECONDS),
+ stream=stream,
  )
  elif tid == plc.TypeId.TIMESTAMP_MILLISECONDS:
  return plc.Scalar.from_py(
- value // 1_000_000, plc.DataType(plc.TypeId.DURATION_MILLISECONDS)
+ value // 1_000_000,
+ plc.DataType(plc.TypeId.DURATION_MILLISECONDS),
+ stream=stream,
  )
  elif tid == plc.TypeId.TIMESTAMP_DAYS:
  return plc.Scalar.from_py(
- value // 86_400_000_000_000, plc.DataType(plc.TypeId.DURATION_DAYS)
+ value // 86_400_000_000_000,
+ plc.DataType(plc.TypeId.DURATION_DAYS),
+ stream=stream,
  )
  else:
  raise NotImplementedError(
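The divisors above convert the raw Polars duration, which arrives in nanoseconds, to the resolution of the timestamp index column. The arithmetic, spelled out:

```python
# One day expressed in nanoseconds, reduced to each supported resolution.
one_day_ns = 86_400_000_000_000               # 24 * 60 * 60 * 10**9
assert one_day_ns // 1_000 == 86_400_000_000             # microseconds
assert one_day_ns // 1_000_000 == 86_400_000             # milliseconds
assert one_day_ns // 86_400_000_000_000 == 1             # days
```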
@@ -122,8 +135,9 @@ def duration_to_scalar(dtype: plc.DataType, value: int) -> plc.Scalar:

  def offsets_to_windows(
  dtype: plc.DataType,
- offset: Duration,
- period: Duration,
+ offset_i: int,
+ period_i: int,
+ stream: Stream,
  ) -> tuple[plc.Scalar, plc.Scalar]:
  """
  Convert polars offset/period pair to preceding/following windows.
@@ -132,21 +146,22 @@
  ----------
  dtype
  Datatype of column defining windows
- offset
- Offset duration
- period
- Period of window
+ offset_i
+ Integer ordinal representing the offset of the window.
+ See :func:`duration_to_int` for more details.
+ period_i
+ Integer ordinal representing the period of the window.
+ See :func:`duration_to_int` for more details.
+ stream
+ CUDA stream used for device memory operations and kernel launches

  Returns
  -------
- tuple of preceding and following windows as pylibcudf scalars.
+ tuple of preceding and following windows as host integers.
  """
- offset_i = duration_to_int(dtype, *offset)
- period_i = duration_to_int(dtype, *period)
- # Polars uses current_row + offset, ..., current_row + offset + period
- # Libcudf uses current_row - preceding, ..., current_row + following
- return duration_to_scalar(dtype, -offset_i), duration_to_scalar(
- dtype, offset_i + period_i
+ return (
+ duration_to_scalar(dtype, -offset_i, stream=stream),
+ duration_to_scalar(dtype, offset_i + period_i, stream=stream),
  )
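The return value keeps the mapping that the removed comment described: Polars specifies a window as `[current_row + offset, current_row + offset + period]`, libcudf as `[current_row - preceding, current_row + following]`, so `preceding = -offset` and `following = offset + period`. A worked example with made-up ordinals:

```python
# Polars (offset, period) -> libcudf (preceding, following), as in
# offsets_to_windows above.
offset_i, period_i = -3, 5        # hypothetical integer ordinals
preceding = -offset_i             # 3: window starts 3 units before the row
following = offset_i + period_i   # 2: window ends 2 units after the row
assert (preceding, following) == (3, 2)
```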
cudf_polars/experimental/base.py

@@ -5,7 +5,9 @@
  from __future__ import annotations

  import dataclasses
+ import enum
  from collections import defaultdict
+ from enum import IntEnum
  from functools import cached_property
  from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar

@@ -20,19 +22,24 @@ if TYPE_CHECKING:
  class PartitionInfo:
  """Partitioning information."""

- __slots__ = ("count", "partitioned_on")
+ __slots__ = ("count", "io_plan", "partitioned_on")
  count: int
  """Partition count."""
  partitioned_on: tuple[NamedExpr, ...]
  """Columns the data is hash-partitioned on."""
+ io_plan: IOPartitionPlan | None
+ """IO partitioning plan (Scan nodes only)."""

  def __init__(
  self,
  count: int,
+ *,
  partitioned_on: tuple[NamedExpr, ...] = (),
+ io_plan: IOPartitionPlan | None = None,
  ):
  self.count = count
  self.partitioned_on = partitioned_on
+ self.io_plan = io_plan

  def keys(self, node: Node) -> Iterator[tuple[str, int]]:
  """Return the partitioned keys for a given node."""
@@ -108,13 +115,17 @@ class DataSourceInfo:
  """

  _unique_stats_columns: set[str]
+ _read_columns: set[str]

  @property
  def row_count(self) -> ColumnStat[int]: # pragma: no cover
  """Data source row-count estimate."""
  raise NotImplementedError("Sub-class must implement row_count.")

- def unique_stats(self, column: str) -> UniqueStats: # pragma: no cover
+ def unique_stats(
+ self,
+ column: str,
+ ) -> UniqueStats: # pragma: no cover
  """Return unique-value statistics for a column."""
  raise NotImplementedError("Sub-class must implement unique_stats.")

@@ -131,6 +142,10 @@ class DataSourceInfo:
  """Add a column needing unique-value information."""
  self._unique_stats_columns.add(column)

+ def add_read_column(self, column: str) -> None:
+ """Add a column needing to be read."""
+ self._read_columns.add(column)
+

  class DataSourcePair(NamedTuple):
  """Pair of table-source and column-name information."""
@@ -230,6 +245,11 @@ class ColumnSourceInfo:
  for table_source, column_name in self.table_source_pairs:
  table_source.add_unique_stats_column(column or column_name)

+ def add_read_column(self, column: str | None = None) -> None:
+ """Add a column needing to be read."""
+ for table_source, column_name in self.table_source_pairs:
+ table_source.add_read_column(column or column_name)
+

  class ColumnStats:
  """
@@ -384,3 +404,36 @@ class StatsCollector:
  self.row_count: dict[IR, ColumnStat[int]] = {}
  self.column_stats: dict[IR, dict[str, ColumnStats]] = {}
  self.join_info = JoinInfo()
+
+
+ class IOPartitionFlavor(IntEnum):
+ """Flavor of IO partitioning."""
+
+ SINGLE_FILE = enum.auto() # 1:1 mapping between files and partitions
+ SPLIT_FILES = enum.auto() # Split each file into >1 partition
+ FUSED_FILES = enum.auto() # Fuse multiple files into each partition
+ SINGLE_READ = enum.auto() # One worker/task reads everything
+
+
+ class IOPartitionPlan:
+ """
+ IO partitioning plan.
+
+ Notes
+ -----
+ The meaning of `factor` depends on the value of `flavor`:
+ - SINGLE_FILE: `factor` must be `1`.
+ - SPLIT_FILES: `factor` is the number of partitions per file.
+ - FUSED_FILES: `factor` is the number of files per partition.
+ - SINGLE_READ: `factor` is the total number of files.
+ """
+
+ __slots__ = ("factor", "flavor")
+ factor: int
+ flavor: IOPartitionFlavor
+
+ def __init__(self, factor: int, flavor: IOPartitionFlavor) -> None:
+ if flavor == IOPartitionFlavor.SINGLE_FILE and factor != 1: # pragma: no cover
+ raise ValueError(f"Expected factor == 1 for {flavor}, got: {factor}")
+ self.factor = factor
+ self.flavor = flavor
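The new `IOPartitionFlavor`/`IOPartitionPlan` pair records how a scan's files map onto output partitions. A standalone sketch, not taken from the package, of how the factor/flavor combination in the Notes above could translate a file count into a partition count (`partition_count` is a hypothetical helper):

```python
# Hypothetical interpretation of (factor, flavor), following the Notes above.
import math
from enum import IntEnum, auto


class IOPartitionFlavor(IntEnum):
    SINGLE_FILE = auto()   # one partition per file
    SPLIT_FILES = auto()   # `factor` partitions per file
    FUSED_FILES = auto()   # `factor` files per partition
    SINGLE_READ = auto()   # a single task reads everything


def partition_count(n_files: int, factor: int, flavor: IOPartitionFlavor) -> int:
    if flavor == IOPartitionFlavor.SINGLE_FILE:
        return n_files
    if flavor == IOPartitionFlavor.SPLIT_FILES:
        return n_files * factor
    if flavor == IOPartitionFlavor.FUSED_FILES:
        return math.ceil(n_files / factor)
    return 1  # SINGLE_READ


assert partition_count(8, 4, IOPartitionFlavor.FUSED_FILES) == 2
```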
cudf_polars/experimental/benchmarks/pdsds.py

@@ -16,26 +16,17 @@ from __future__ import annotations
  import contextlib
  import importlib
  import os
- import time
- from collections import defaultdict
- from pathlib import Path
  from typing import TYPE_CHECKING

- import polars as pl
-
  with contextlib.suppress(ImportError):
  from cudf_polars.experimental.benchmarks.utils import (
- Record,
- RunConfig,
- get_executor_options,
- parse_args,
+ run_duckdb,
  run_polars,
+ run_validate,
  )

  if TYPE_CHECKING:
- from collections.abc import Sequence
  from types import ModuleType
- from typing import Any

  # Without this setting, the first IO task to run
  # on each worker takes ~15 sec extra
@@ -58,7 +49,7 @@ def valid_query(name: str) -> bool:
  class PDSDSQueriesMeta(type):
  """Metaclass used for query lookup."""

- def __getattr__(cls, name: str): # type: ignore
+ def __getattr__(cls, name: str): # type: ignore[no-untyped-def]
  """Query lookup."""
  if valid_query(name):
  q_num = int(name[1:])
@@ -88,118 +79,6 @@ class PDSDSDuckDBQueries(PDSDSQueries):
  q_impl = "duckdb_impl"


- def execute_duckdb_query(query: str, dataset_path: Path) -> pl.DataFrame:
- """Execute a query with DuckDB."""
- import duckdb
-
- conn = duckdb.connect()
-
- statements = [
- f"CREATE VIEW {table.stem} as SELECT * FROM read_parquet('{table.absolute()}');"
- for table in Path(dataset_path).glob("*.parquet")
- ]
- statements.append(query)
- return conn.execute("\n".join(statements)).pl()
-
-
- def run_duckdb(benchmark: Any, options: Sequence[str] | None = None) -> None:
- """Run the benchmark with DuckDB."""
- args = parse_args(options, num_queries=99)
- vars(args).update({"query_set": benchmark.name})
- run_config = RunConfig.from_args(args)
- records: defaultdict[int, list[Record]] = defaultdict(list)
-
- for q_id in run_config.queries:
- try:
- duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
- except AttributeError as err:
- raise NotImplementedError(f"Query {q_id} not implemented.") from err
-
- print(f"DuckDB Executing: {q_id}")
- records[q_id] = []
-
- for i in range(args.iterations):
- t0 = time.time()
-
- result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
-
- t1 = time.time()
- record = Record(query=q_id, duration=t1 - t0)
- if args.print_results:
- print(result)
-
- print(f"Query {q_id} - Iteration {i} finished in {record.duration:0.4f}s")
- records[q_id].append(record)
-
-
- def run_validate(benchmark: Any, options: Sequence[str] | None = None) -> None:
- """Validate Polars CPU vs DuckDB or Polars GPU."""
- from polars.testing import assert_frame_equal
-
- args = parse_args(options, num_queries=99)
- vars(args).update({"query_set": benchmark.name})
- run_config = RunConfig.from_args(args)
-
- baseline = args.baseline
- if baseline not in {"duckdb", "cpu"}:
- raise ValueError("Baseline must be one of: 'duckdb', 'cpu'")
-
- failures: list[int] = []
-
- engine: pl.GPUEngine | None = None
- if run_config.executor != "cpu":
- engine = pl.GPUEngine(
- raise_on_fail=True,
- executor=run_config.executor,
- executor_options=get_executor_options(run_config, PDSDSPolarsQueries),
- )
-
- for q_id in run_config.queries:
- print(f"\nValidating Query {q_id}")
- try:
- polars_query = getattr(PDSDSPolarsQueries, f"q{q_id}")(run_config)
- duckdb_query = getattr(PDSDSDuckDBQueries, f"q{q_id}")(run_config)
- except AttributeError as err:
- raise NotImplementedError(f"Query {q_id} not implemented.") from err
-
- if baseline == "duckdb":
- base_result = execute_duckdb_query(duckdb_query, run_config.dataset_path)
- elif baseline == "cpu":
- base_result = polars_query.collect(new_streaming=True)
-
- if run_config.executor == "cpu":
- test_result = polars_query.collect(new_streaming=True)
- else:
- try:
- test_result = polars_query.collect(engine=engine)
- except Exception as e:
- failures.append(q_id)
- print(f"❌ Query {q_id} failed validation: GPU execution failed.\n{e}")
- continue
-
- try:
- assert_frame_equal(
- base_result,
- test_result,
- check_dtypes=True,
- check_column_order=False,
- )
- print(f"✅ Query {q_id} passed validation.")
- except AssertionError as e:
- failures.append(q_id)
- print(f"❌ Query {q_id} failed validation:\n{e}")
- if args.print_results:
- print("Baseline Result:\n", base_result)
- print("Test Result:\n", test_result)
-
- if failures:
- print("\nValidation Summary:")
- print("===================")
- print(f"{len(failures)} query(s) failed: {failures}")
- else:
- print("\nAll queries passed validation.")
-
-
  if __name__ == "__main__":
  import argparse

@@ -215,6 +94,13 @@ if __name__ == "__main__":
  if args.engine == "polars":
  run_polars(PDSDSPolarsQueries, extra_args, num_queries=99)
  elif args.engine == "duckdb":
- run_duckdb(PDSDSDuckDBQueries, extra_args)
+ run_duckdb(PDSDSDuckDBQueries, extra_args, num_queries=99)
  elif args.engine == "validate":
- run_validate(PDSDSQueries, extra_args)
+ run_validate(
+ PDSDSPolarsQueries,
+ PDSDSDuckDBQueries,
+ extra_args,
+ num_queries=99,
+ check_dtypes=True,
+ check_column_order=True,
+ )