cudf_polars_cu13-25.10.0-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/experimental/utils.py ADDED
@@ -0,0 +1,169 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+ """Multi-partition utilities."""
+
+ from __future__ import annotations
+
+ import operator
+ import warnings
+ from functools import reduce
+ from itertools import chain
+ from typing import TYPE_CHECKING
+
+ from cudf_polars.dsl.expr import Col, Expr, GroupedRollingWindow, UnaryFunction
+ from cudf_polars.dsl.ir import Union
+ from cudf_polars.dsl.traversal import traversal
+ from cudf_polars.experimental.base import ColumnStat, PartitionInfo
+
+ if TYPE_CHECKING:
+     from collections.abc import MutableMapping, Sequence
+
+     from cudf_polars.containers import DataFrame
+     from cudf_polars.dsl.expr import Expr
+     from cudf_polars.dsl.ir import IR
+     from cudf_polars.experimental.base import ColumnStats
+     from cudf_polars.experimental.dispatch import LowerIRTransformer
+     from cudf_polars.utils.config import ConfigOptions
+
+
+ def _concat(*dfs: DataFrame) -> DataFrame:
+     # Concatenate a sequence of DataFrames vertically
+     return Union.do_evaluate(None, *dfs)
+
+
34
+ def _fallback_inform(msg: str, config_options: ConfigOptions) -> None:
+     """Inform the user of single-partition fallback."""
+     # Satisfy type checking
+     assert config_options.executor.name == "streaming", (
+         "'in-memory' executor not supported in '_fallback_inform'"
+     )
+
+     match fallback_mode := config_options.executor.fallback_mode:
+         case "warn":
+             warnings.warn(msg, stacklevel=2)
+         case "raise":
+             raise NotImplementedError(msg)
+         case "silent":
+             pass
+         case _:  # pragma: no cover; Should never get here.
+             raise ValueError(
+                 f"{fallback_mode} is not a supported 'fallback_mode' "
+                 "option. Please use 'warn', 'raise', or 'silent'."
+             )
+
+
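The `fallback_mode` consumed by `_fallback_inform` above is an executor option on the GPU engine. A minimal sketch of turning fallback warnings into hard errors (the query is illustrative; whether a particular query falls back depends on the release):

    import polars as pl

    q = pl.LazyFrame({"key": [1, 1, 2], "x": [0.5, 1.5, 2.5]})
    # "warn", "raise", and "silent" are the values accepted by the
    # match statement above.
    engine = pl.GPUEngine(
        executor="streaming",
        executor_options={"fallback_mode": "raise"},
    )
    result = q.group_by("key").agg(pl.col("x").mean()).collect(engine=engine)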
55
+ def _lower_ir_fallback(
+     ir: IR,
+     rec: LowerIRTransformer,
+     *,
+     msg: str | None = None,
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+     # Catch-all single-partition lowering logic.
+     # If any children contain multiple partitions,
+     # those children will be collapsed with `Repartition`.
+     from cudf_polars.experimental.repartition import Repartition
+
+     # Lower children
+     lowered_children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
+     partition_info = reduce(operator.or_, _partition_info)
+
+     # Ensure all children are single-partitioned
+     children = []
+     fallback = False
+     for c in lowered_children:
+         child = c
+         if partition_info[c].count > 1:
+             # Fall-back logic
+             fallback = True
+             child = Repartition(child.schema, child)
+             partition_info[child] = PartitionInfo(count=1)
+         children.append(child)
+
+     if fallback and msg:
+         # Warn/raise the user if any children were collapsed
+         # and the "fallback_mode" configuration is not "silent"
+         _fallback_inform(msg, rec.state["config_options"])
+
+     # Reconstruct and return
+     new_node = ir.reconstruct(children)
+     partition_info[new_node] = PartitionInfo(count=1)
+     return new_node, partition_info
+
+
93
+ def _leaf_column_names(expr: Expr) -> tuple[str, ...]:
+     """Find the leaf column names of an expression."""
+     if expr.children:
+         return tuple(
+             chain.from_iterable(_leaf_column_names(child) for child in expr.children)
+         )
+     elif isinstance(expr, Col):
+         return (expr.name,)
+     else:
+         return ()
+
+
105
+ def _get_unique_fractions(
+     column_names: Sequence[str],
+     user_unique_fractions: dict[str, float],
+     *,
+     row_count: ColumnStat[int] | None = None,
+     column_stats: dict[str, ColumnStats] | None = None,
+ ) -> dict[str, float]:
+     """
+     Return the unique-fraction statistics for a subset of columns.
+
+     Parameters
+     ----------
+     column_names
+         The column names to get unique-fractions for.
+     user_unique_fractions
+         The user-provided unique-fraction dictionary.
+     row_count
+         Row-count statistics. This will be None if
+         statistics planning is not enabled.
+     column_stats
+         The column statistics. This will be None if
+         statistics planning is not enabled.
+
+     Returns
+     -------
+     unique_fractions
+         The final unique-fraction dictionary.
+     """
+     unique_fractions: dict[str, float] = {}
+     column_stats = column_stats or {}
+     row_count = row_count or ColumnStat[int](None)
+     if isinstance(row_count.value, int) and row_count.value > 0:
+         for c in set(column_names).intersection(column_stats):
+             if (unique_count := column_stats[c].unique_count.value) is not None:
+                 # Use the unique-count statistic (if available)
+                 unique_fractions[c] = max(
+                     min(1.0, unique_count / row_count.value),
+                     0.00001,
+                 )
+
+     # Update with user-provided unique-fractions,
+     # clamped into the same range
+     unique_fractions.update(
+         {
+             c: max(min(f, 1.0), 0.00001)
+             for c, f in user_unique_fractions.items()
+             if c in column_names
+         }
+     )
+     return unique_fractions
+
+
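The clamping above keeps every fraction in [1e-5, 1.0], and user-provided values override computed statistics. A worked sketch of the arithmetic with hypothetical numbers:

    row_count = 1000
    unique_count = 150  # from column statistics
    fraction = max(min(1.0, unique_count / row_count), 0.00001)  # -> 0.15

    user_value = 2.5  # user-supplied, out of range
    clamped = max(min(user_value, 1.0), 0.00001)  # -> 1.0, replacing any computed value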
156
+ def _contains_over(exprs: Sequence[Expr]) -> bool:
+     """Return True if any expression in 'exprs' contains an over(...) (i.e. a GroupedRollingWindow)."""
+     return any(isinstance(e, GroupedRollingWindow) for e in traversal(exprs))
+
+
161
+ def _contains_unsupported_fill_strategy(exprs: Sequence[Expr]) -> bool:
+     """Return True if any expression uses a fill-null strategy other than 'zero' or 'one'."""
+     for e in traversal(exprs):
+         if (
+             isinstance(e, UnaryFunction)
+             and e.name == "fill_null_with_strategy"
+             and e.options[0] not in ("zero", "one")
+         ):
+             return True
+     return False
cudf_polars/py.typed ADDED
File without changes
cudf_polars/testing/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+ """Testing utilities for cudf_polars."""
+
+ from __future__ import annotations
+
+ __all__: list[str] = []
cudf_polars/testing/asserts.py ADDED
@@ -0,0 +1,448 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+
+ """Device-aware assertions."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Literal
+
+ import polars as pl
+ from polars import GPUEngine
+ from polars.testing.asserts import assert_frame_equal
+
+ from cudf_polars.dsl.translate import Translator
+ from cudf_polars.utils.config import ConfigOptions, StreamingFallbackMode
+ from cudf_polars.utils.versions import POLARS_VERSION_LT_1323
+
+ if TYPE_CHECKING:
+     from cudf_polars.typing import OptimizationArgs
+
+
+ __all__: list[str] = [
+     "assert_gpu_result_equal",
+     "assert_ir_translation_raises",
+     "assert_sink_ir_translation_raises",
+     "assert_sink_result_equal",
+ ]
+
+ # Will be overridden by `conftest.py` with the value from the `--executor`
+ # and `--scheduler` command-line arguments
+ DEFAULT_EXECUTOR = "in-memory"
+ DEFAULT_SCHEDULER = "synchronous"
+ DEFAULT_BLOCKSIZE_MODE: Literal["small", "default"] = "default"
+
+
37
+ def assert_gpu_result_equal(
+     lazydf: pl.LazyFrame,
+     *,
+     engine: GPUEngine | None = None,
+     collect_kwargs: dict[OptimizationArgs, bool] | None = None,
+     polars_collect_kwargs: dict[OptimizationArgs, bool] | None = None,
+     cudf_collect_kwargs: dict[OptimizationArgs, bool] | None = None,
+     check_row_order: bool = True,
+     check_column_order: bool = True,
+     check_dtypes: bool = True,
+     check_exact: bool = True,
+     rtol: float = 1e-05,
+     atol: float = 1e-08,
+     categorical_as_str: bool = False,
+     executor: str | None = None,
+     blocksize_mode: Literal["small", "default"] | None = None,
+ ) -> None:
+     """
+     Assert that collection of a lazyframe on GPU produces correct results.
+
+     Parameters
+     ----------
+     lazydf
+         Frame to collect.
+     engine
+         Custom GPU engine configuration.
+     collect_kwargs
+         Common keyword arguments to pass to collect for both polars CPU and
+         cudf-polars. Useful for controlling optimization settings.
+     polars_collect_kwargs
+         Keyword arguments to pass to collect for execution on polars CPU.
+         Overrides kwargs in collect_kwargs. Useful for controlling
+         optimization settings.
+     cudf_collect_kwargs
+         Keyword arguments to pass to collect for execution on cudf-polars.
+         Overrides kwargs in collect_kwargs. Useful for controlling
+         optimization settings.
+     check_row_order
+         Expect rows to be in the same order.
+     check_column_order
+         Expect columns to be in the same order.
+     check_dtypes
+         Expect dtypes to match.
+     check_exact
+         Require exact equality for floats; if `False`, compare using
+         rtol and atol.
+     rtol
+         Relative tolerance for float comparisons.
+     atol
+         Absolute tolerance for float comparisons.
+     categorical_as_str
+         Cast categoricals to strings before comparing.
+     executor
+         The executor configuration to pass to `GPUEngine`. If not specified,
+         uses the module-level ``DEFAULT_EXECUTOR`` attribute.
+     blocksize_mode
+         The "mode" to use for choosing the blocksize for the streaming executor.
+         If not specified, uses the module-level ``DEFAULT_BLOCKSIZE_MODE`` attribute.
+         Set to "small" to configure small values for ``max_rows_per_partition``
+         and ``target_partition_size``, which will typically cause many partitions
+         to be created while executing the query.
+
+     Raises
+     ------
+     AssertionError
+         If the GPU and CPU collection do not match.
+     NotImplementedError
+         If GPU collection failed in some way.
+     """
+     engine = engine or get_default_engine(executor, blocksize_mode)
+     final_polars_collect_kwargs, final_cudf_collect_kwargs = _process_kwargs(
+         collect_kwargs, polars_collect_kwargs, cudf_collect_kwargs
+     )
+
+     # These keywords are correct, but mypy doesn't see that.
+     # The 'misc' is for 'error: Keywords must be strings'.
+     expect = lazydf.collect(**final_polars_collect_kwargs)  # type: ignore[call-overload,misc]
+     got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine)  # type: ignore[call-overload,misc]
+
+     assert_kwargs_bool: dict[str, bool] = {
+         "check_row_order": check_row_order,
+         "check_column_order": check_column_order,
+         "check_dtypes": check_dtypes,
+         "check_exact": check_exact,
+         "categorical_as_str": categorical_as_str,
+     }
+
+     tol_kwargs: dict[str, float]
+     if POLARS_VERSION_LT_1323:  # pragma: no cover
+         tol_kwargs = {"rtol": rtol, "atol": atol}
+     else:
+         tol_kwargs = {"rel_tol": rtol, "abs_tol": atol}
+
+     assert_frame_equal(
+         expect,
+         got,
+         **assert_kwargs_bool,
+         **tol_kwargs,
+     )
+
+
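In a test, this helper replaces a manual CPU/GPU collect-and-compare. A minimal sketch (the query is illustrative; a CUDA-capable GPU is required at runtime):

    import polars as pl
    from cudf_polars.testing.asserts import assert_gpu_result_equal

    q = pl.LazyFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}).select(
        pl.col("a").sum(), pl.col("b").mean()
    )
    # Collects once on CPU and once on GPU, then compares the frames.
    assert_gpu_result_equal(q)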
139
+ def assert_ir_translation_raises(q: pl.LazyFrame, *exceptions: type[Exception]) -> None:
+     """
+     Assert that translation of a query raises an exception.
+
+     Parameters
+     ----------
+     q
+         Query to translate.
+     exceptions
+         Exceptions that one expects might be raised.
+
+     Returns
+     -------
+     None
+         If translation successfully raised the specified exceptions.
+
+     Raises
+     ------
+     AssertionError
+         If the specified exceptions were not raised.
+     """
+     translator = Translator(q._ldf.visit(), GPUEngine())
+     translator.translate_ir()
+     if errors := translator.errors:
+         for err in errors:
+             assert any(isinstance(err, err_type) for err_type in exceptions), (
+                 f"Translation DID NOT RAISE {exceptions}. The following "
+                 f"errors were seen instead: {errors}"
+             )
+         return
+     else:
+         raise AssertionError(f"Translation DID NOT RAISE {exceptions}")
+
+
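A sketch of typical usage. Arbitrary Python UDFs are assumed here to be untranslatable to the GPU IR (a common case, but which constructs are unsupported depends on the release):

    import polars as pl
    from cudf_polars.testing.asserts import assert_ir_translation_raises

    # Assumption: a Python UDF cannot be translated; substitute whatever
    # construct is actually unsupported in the release under test.
    q = pl.LazyFrame({"a": [1, 2, 3]}).select(
        pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)
    )
    assert_ir_translation_raises(q, NotImplementedError)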
173
+ def get_default_engine(
+     executor: str | None = None,
+     blocksize_mode: Literal["small", "default"] | None = None,
+ ) -> GPUEngine:
+     """
+     Get the default engine used for testing.
+
+     Parameters
+     ----------
+     executor
+         The executor configuration to pass to `GPUEngine`. If not specified,
+         uses the module-level ``DEFAULT_EXECUTOR`` attribute.
+     blocksize_mode
+         The "mode" to use for choosing the blocksize for the streaming executor.
+         If not specified, uses the module-level ``DEFAULT_BLOCKSIZE_MODE`` attribute.
+         Set to "small" to configure small values for ``max_rows_per_partition``
+         and ``target_partition_size``, which will typically cause many partitions
+         to be created while executing the query.
+
+     Returns
+     -------
+     engine
+         A polars GPUEngine configured with the default settings for tests.
+
+     See Also
+     --------
+     assert_gpu_result_equal
+     assert_sink_result_equal
+     """
+     executor_options: dict[str, Any] = {}
+     executor = executor or DEFAULT_EXECUTOR
+     if executor == "streaming":
+         executor_options["scheduler"] = DEFAULT_SCHEDULER
+
+     blocksize_mode = blocksize_mode or DEFAULT_BLOCKSIZE_MODE
+
+     if blocksize_mode == "small":  # pragma: no cover
+         executor_options["max_rows_per_partition"] = 4
+         executor_options["target_partition_size"] = 10
+         # We expect many tests to fall back, so silence the warnings
+         executor_options["fallback_mode"] = StreamingFallbackMode.SILENT
+
+     return GPUEngine(
+         raise_on_fail=True,
+         executor=executor,
+         executor_options=executor_options,
+     )
+
+
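Following the branches above, `get_default_engine("streaming", "small")` expands to the equivalent hand-built engine:

    from polars import GPUEngine
    from cudf_polars.utils.config import StreamingFallbackMode

    engine = GPUEngine(
        raise_on_fail=True,
        executor="streaming",
        executor_options={
            "scheduler": "synchronous",  # DEFAULT_SCHEDULER
            "max_rows_per_partition": 4,  # "small" blocksize mode
            "target_partition_size": 10,
            "fallback_mode": StreamingFallbackMode.SILENT,
        },
    )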
222
+ def _process_kwargs(
+     collect_kwargs: dict[OptimizationArgs, bool] | None,
+     polars_collect_kwargs: dict[OptimizationArgs, bool] | None,
+     cudf_collect_kwargs: dict[OptimizationArgs, bool] | None,
+ ) -> tuple[dict[OptimizationArgs, bool], dict[OptimizationArgs, bool]]:
+     if collect_kwargs is None:
+         collect_kwargs = {}
+     final_polars_collect_kwargs = collect_kwargs.copy()
+     final_cudf_collect_kwargs = collect_kwargs.copy()
+     if polars_collect_kwargs is not None:  # pragma: no cover; not currently used
+         final_polars_collect_kwargs.update(polars_collect_kwargs)
+     if cudf_collect_kwargs is not None:  # pragma: no cover; not currently used
+         final_cudf_collect_kwargs.update(cudf_collect_kwargs)
+     return final_polars_collect_kwargs, final_cudf_collect_kwargs
+
+
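`_process_kwargs` is a pure function, so its merge semantics are easy to pin down. The flag names below are illustrative optimization arguments:

    polars_kw, cudf_kw = _process_kwargs(
        {"predicate_pushdown": False},  # shared by both engines
        None,                           # no CPU-only overrides
        {"slice_pushdown": False},      # GPU-only override
    )
    assert polars_kw == {"predicate_pushdown": False}
    assert cudf_kw == {"predicate_pushdown": False, "slice_pushdown": False}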
238
+ def assert_collect_raises(
+     lazydf: pl.LazyFrame,
+     *,
+     polars_except: type[Exception] | tuple[type[Exception], ...],
+     cudf_except: type[Exception] | tuple[type[Exception], ...],
+     collect_kwargs: dict[OptimizationArgs, bool] | None = None,
+     polars_collect_kwargs: dict[OptimizationArgs, bool] | None = None,
+     cudf_collect_kwargs: dict[OptimizationArgs, bool] | None = None,
+ ) -> None:
+     """
+     Assert that collecting the result of a query raises the expected exceptions.
+
+     Parameters
+     ----------
+     lazydf
+         Frame to collect.
+     polars_except
+         Exception or exceptions polars CPU is expected to raise. If
+         an empty tuple ``()``, CPU is expected to succeed without raising.
+     cudf_except
+         Exception or exceptions polars GPU is expected to raise. If
+         an empty tuple ``()``, GPU is expected to succeed without raising.
+     collect_kwargs
+         Common keyword arguments to pass to collect for both polars CPU and
+         cudf-polars. Useful for controlling optimization settings.
+     polars_collect_kwargs
+         Keyword arguments to pass to collect for execution on polars CPU.
+         Overrides kwargs in collect_kwargs. Useful for controlling
+         optimization settings.
+     cudf_collect_kwargs
+         Keyword arguments to pass to collect for execution on cudf-polars.
+         Overrides kwargs in collect_kwargs. Useful for controlling
+         optimization settings.
+
+     Returns
+     -------
+     None
+         If both sides raise the expected exceptions.
+
+     Raises
+     ------
+     AssertionError
+         If either side did not raise the expected exceptions.
+     """
+     final_polars_collect_kwargs, final_cudf_collect_kwargs = _process_kwargs(
+         collect_kwargs, polars_collect_kwargs, cudf_collect_kwargs
+     )
+
+     try:
+         lazydf.collect(**final_polars_collect_kwargs)  # type: ignore[call-overload,misc]
+     except polars_except:
+         pass
+     except Exception as e:
+         raise AssertionError(
+             f"CPU execution RAISED {type(e)}, EXPECTED {polars_except}"
+         ) from e
+     else:
+         if polars_except != ():
+             raise AssertionError(f"CPU execution DID NOT RAISE {polars_except}")
+
+     engine = GPUEngine(raise_on_fail=True)
+     try:
+         lazydf.collect(**final_cudf_collect_kwargs, engine=engine)  # type: ignore[call-overload,misc]
+     except cudf_except:
+         pass
+     except Exception as e:
+         raise AssertionError(
+             f"GPU execution RAISED {type(e)}, EXPECTED {cudf_except}"
+         ) from e
+     else:
+         if cudf_except != ():
+             raise AssertionError(f"GPU execution DID NOT RAISE {cudf_except}")
+
+
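A sketch of the intended usage, with a query expected to fail on both engines. The GPU-side exception type is an assumption and may differ by release:

    import polars as pl
    from cudf_polars.testing.asserts import assert_collect_raises

    # A strict cast of non-numeric strings fails at execution time on CPU.
    q = pl.LazyFrame({"a": ["not-a-number"]}).with_columns(pl.col("a").cast(pl.Int64))
    assert_collect_raises(
        q,
        polars_except=pl.exceptions.InvalidOperationError,
        # Assumed; use the exception type the GPU engine actually raises.
        cudf_except=pl.exceptions.InvalidOperationError,
    )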
317
+ def _resolve_sink_format(path: Path) -> str:
+     """Return a valid sink format for the assert utilities."""
+     suffix = path.suffix.lower()
+     supported_ext = {
+         ".csv": "csv",
+         ".pq": "parquet",
+         ".parquet": "parquet",
+         ".json": "ndjson",
+         ".ndjson": "ndjson",
+     }
+     if suffix not in supported_ext:
+         raise ValueError(f"Unsupported file format: {suffix}")
+     return supported_ext[suffix]
+
+
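Per the mapping above:

    from pathlib import Path

    _resolve_sink_format(Path("out.pq"))       # -> "parquet"
    _resolve_sink_format(Path("data.NDJSON"))  # -> "ndjson" (suffix is lower-cased)
    _resolve_sink_format(Path("table.xlsx"))   # -> ValueError: Unsupported file format: .xlsx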
332
+ def assert_sink_result_equal(
+     lazydf: pl.LazyFrame,
+     path: str | Path,
+     *,
+     engine: str | GPUEngine | None = None,
+     read_kwargs: dict | None = None,
+     write_kwargs: dict | None = None,
+     executor: str | None = None,
+     blocksize_mode: Literal["small", "default"] | None = None,
+ ) -> None:
+     """
+     Assert that writing a LazyFrame via sink produces the same output.
+
+     Parameters
+     ----------
+     lazydf
+         The LazyFrame to sink.
+     path
+         The file path to use. Suffix must be one of:
+         '.csv', '.parquet', '.pq', '.json', '.ndjson'.
+     engine
+         The GPU engine to use for the sink operation.
+     read_kwargs
+         Optional keyword arguments to pass to the corresponding `pl.read_*` function.
+     write_kwargs
+         Optional keyword arguments to pass to the corresponding `sink_*` function.
+     executor
+         The executor configuration to pass to `GPUEngine`. If not specified,
+         uses the module-level ``DEFAULT_EXECUTOR`` attribute.
+     blocksize_mode
+         The "mode" to use for choosing the blocksize for the streaming executor.
+         If not specified, uses the module-level ``DEFAULT_BLOCKSIZE_MODE`` attribute.
+         Set to "small" to configure small values for ``max_rows_per_partition``
+         and ``target_partition_size``, which will typically cause many partitions
+         to be created while executing the query.
+
+     Raises
+     ------
+     AssertionError
+         If the outputs from CPU and GPU sink differ.
+     ValueError
+         If the file extension is not one of the supported formats.
+     """
+     engine = engine or get_default_engine(executor, blocksize_mode)
+     path = Path(path)
+     read_kwargs = read_kwargs or {}
+     write_kwargs = write_kwargs or {}
+
+     fmt = _resolve_sink_format(path)
+
+     cpu_path = path.with_name(f"{path.stem}_cpu{path.suffix}")
+     gpu_path = path.with_name(f"{path.stem}_gpu{path.suffix}")
+
+     sink_fn = getattr(lazydf, f"sink_{fmt}")
+     read_fn = getattr(pl, f"read_{fmt}")
+
+     sink_fn(cpu_path, **write_kwargs)
+     sink_fn(gpu_path, engine=engine, **write_kwargs)
+
+     expected = read_fn(cpu_path, **read_kwargs)
+     # The multi-partition executor might produce multiple files, one per partition.
+     if (
+         isinstance(engine, GPUEngine)
+         and ConfigOptions.from_polars_engine(engine).executor.name == "streaming"
+         and gpu_path.is_dir()
+     ):  # pragma: no cover
+         result = read_fn(gpu_path.joinpath("*"), **read_kwargs)
+     else:
+         result = read_fn(gpu_path, **read_kwargs)
+
+     assert_frame_equal(expected, result)
+
+
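A sketch of typical pytest usage (the query is illustrative; a GPU is required at runtime):

    import polars as pl
    from cudf_polars.testing.asserts import assert_sink_result_equal

    def test_sink_parquet(tmp_path):  # pytest's tmp_path fixture
        q = pl.LazyFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
        # Writes out_cpu.parquet and out_gpu.parquet side by side,
        # reads both back with pl.read_parquet, and compares the frames.
        assert_sink_result_equal(q, tmp_path / "out.parquet")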
405
+ def assert_sink_ir_translation_raises(
+     lazydf: pl.LazyFrame,
+     path: str | Path,
+     write_kwargs: dict,
+     *exceptions: type[Exception],
+ ) -> None:
+     """
+     Assert that translation of a sink query raises an exception.
+
+     Parameters
+     ----------
+     lazydf
+         The LazyFrame to sink.
+     path
+         The file path. Must have one of the supported suffixes.
+     write_kwargs
+         Keyword arguments to pass to the `sink_*` method.
+     *exceptions
+         One or more expected exception types that should be raised during translation.
+
+     Raises
+     ------
+     AssertionError
+         If translation does not raise any of the expected exceptions,
+         or if an exception occurs before translation begins.
+     ValueError
+         If the file extension is not one of the supported formats.
+     """
+     path = Path(path)
+     fmt = _resolve_sink_format(path)
+
+     try:
+         lazy_sink = getattr(lazydf, f"sink_{fmt}")(
+             path,
+             engine="gpu",
+             lazy=True,
+             **write_kwargs,
+         )
+     except Exception as e:
+         raise AssertionError(
+             f"Sink function raised an exception before translation: {e}"
+         ) from e
+
+     assert_ir_translation_raises(lazy_sink, *exceptions)