cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,10 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
  # SPDX-License-Identifier: Apache-2.0
  """Sorting Logic."""

  from __future__ import annotations

+ from functools import partial
  from typing import TYPE_CHECKING, Any, Literal, TypedDict

  import polars as pl
@@ -22,11 +23,19 @@ from cudf_polars.experimental.repartition import Repartition
  from cudf_polars.experimental.shuffle import _simple_shuffle_graph
  from cudf_polars.experimental.utils import _concat, _fallback_inform, _lower_ir_fallback
  from cudf_polars.utils.config import ShuffleMethod
+ from cudf_polars.utils.cuda_stream import (
+ get_dask_cuda_stream,
+ get_joined_cuda_stream,
+ join_cuda_streams,
+ )

  if TYPE_CHECKING:
  from collections.abc import MutableMapping, Sequence

+ from rmm.pylibrmm.stream import Stream
+
  from cudf_polars.dsl.expr import NamedExpr
+ from cudf_polars.dsl.ir import IRExecutionContext
  from cudf_polars.experimental.dispatch import LowerIRTransformer
  from cudf_polars.typing import Schema

@@ -37,6 +46,7 @@ def find_sort_splits(
  my_part_id: int,
  column_order: Sequence[plc.types.Order],
  null_order: Sequence[plc.types.NullOrder],
+ stream: Stream,
  ) -> list[int]:
  """
  Find local sort splits given all (global) split candidates.
@@ -59,6 +69,10 @@ def find_sort_splits(
  The order in which tbl is sorted.
  null_order
  The null order in which tbl is sorted.
+ stream
+ CUDA stream used for device memory operations and kernel launches.
+ The values in both ``tbl`` and ``sort_boundaries`` must be valid on
+ ``stream``.

  Returns
  -------
@@ -69,28 +83,44 @@ def find_sort_splits(

  # We now need to find the local split points. To do this, first split out
  # the partition id and the local row number of the final split values
- *sort_boundaries, split_part_id, split_local_row = sort_boundaries.columns()
- sort_boundaries = plc.Table(sort_boundaries)
+ *boundary_cols, split_part_id, split_local_row = sort_boundaries.columns()
+ sort_boundaries = plc.Table(boundary_cols)
  # Now we find the first and last row in the local table corresponding to the split value
  # (first and last, because there may be multiple rows with the same split value)
  split_first_col = plc.search.lower_bound(
- tbl, sort_boundaries, column_order, null_order
+ tbl,
+ sort_boundaries,
+ column_order,
+ null_order,
+ stream=stream,
  )
  split_last_col = plc.search.upper_bound(
- tbl, sort_boundaries, column_order, null_order
+ tbl,
+ sort_boundaries,
+ column_order,
+ null_order,
+ stream=stream,
  )
  # And convert to list for final processing
- split_first_col = pl.Series(split_first_col).to_list()
- split_last_col = pl.Series(split_last_col).to_list()
- split_part_id = pl.Series(split_part_id).to_list()
- split_local_row = pl.Series(split_local_row).to_list()
+ # The type ignores are for cross-library boundaries: plc.Column -> pl.Series
+ # These work at runtime via the Arrow C Data Interface protocol
+ # TODO: Find a way for pylibcudf types to show they export the Arrow protocol
+ # (mypy wasn't happy with a custom protocol)
+ split_first_list = pl.Series(split_first_col).to_list()
+ split_last_list = pl.Series(split_last_col).to_list()
+ split_part_id_list = pl.Series(split_part_id).to_list()
+ split_local_row_list = pl.Series(split_local_row).to_list()

  # Find the final split points. This is slightly tricky because of the possibility
  # of equal values, which is why we need the part_id and local_row.
  # Consider for example the case when all data is equal.
  split_points = []
  for first, last, part_id, local_row in zip(
- split_first_col, split_last_col, split_part_id, split_local_row, strict=False
+ split_first_list,
+ split_last_list,
+ split_part_id_list,
+ split_local_row_list,
+ strict=False,
  ):
  if part_id < my_part_id:
  # Local data is globally later so split at first valid row.
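
The hunk above threads an explicit CUDA stream through the lower_bound/upper_bound searches that locate each boundary value in the locally sorted table. A minimal, self-contained sketch of that search step follows; it reuses the pylibcudf search calls shown in the hunk, while the data values, the dtype construction, and the omission of the stream argument are illustrative only, not taken from the package:

import polars as pl
import pylibcudf as plc

# A locally sorted key column and two global boundary values to locate.
keys = plc.Column.from_iterable_of_py([1, 3, 3, 3, 7, 9], plc.DataType(plc.TypeId.INT64))
bounds = plc.Column.from_iterable_of_py([3, 8], plc.DataType(plc.TypeId.INT64))
order = [plc.types.Order.ASCENDING]
nulls = [plc.types.NullOrder.AFTER]

# First and last possible insertion points for each boundary value.
first = plc.search.lower_bound(plc.Table([keys]), plc.Table([bounds]), order, nulls)
last = plc.search.upper_bound(plc.Table([keys]), plc.Table([bounds]), order, nulls)

print(pl.Series(first).to_list())  # [1, 5]: first row that is >= each boundary
print(pl.Series(last).to_list())   # [4, 5]: first row that is > each boundary

Because equal keys can straddle a boundary (first != last for the value 3 above), find_sort_splits also carries the partition id and local row number of each boundary so that every rank picks a consistent global split point.
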
@@ -126,31 +156,42 @@ def _select_local_split_candidates(
  [
  *df.columns,
  Column(
- plc.column_factories.make_empty_column(part_id_dtype.plc),
+ plc.column_factories.make_empty_column(
+ part_id_dtype.plc_type, stream=df.stream
+ ),
  dtype=part_id_dtype,
  name=next(name_gen),
  ),
  Column(
- plc.column_factories.make_empty_column(part_id_dtype.plc),
+ plc.column_factories.make_empty_column(
+ part_id_dtype.plc_type, stream=df.stream
+ ),
  dtype=part_id_dtype,
  name=next(name_gen),
  ),
- ]
+ ],
+ stream=df.stream,
  )

  candidates = [i * df.num_rows // num_partitions for i in range(num_partitions)]
- row_id = plc.Column.from_iterable_of_py(candidates, part_id_dtype.plc)
+ row_id = plc.Column.from_iterable_of_py(
+ candidates, part_id_dtype.plc_type, stream=df.stream
+ )

- res = plc.copying.gather(df.table, row_id, plc.copying.OutOfBoundsPolicy.DONT_CHECK)
+ res = plc.copying.gather(
+ df.table, row_id, plc.copying.OutOfBoundsPolicy.DONT_CHECK, stream=df.stream
+ )
  part_id = plc.Column.from_scalar(
- plc.Scalar.from_py(my_part_id, part_id_dtype.plc),
+ plc.Scalar.from_py(my_part_id, part_id_dtype.plc_type, stream=df.stream),
  len(candidates),
+ stream=df.stream,
  )

  return DataFrame.from_table(
  plc.Table([*res.columns(), part_id, row_id]),
  [*df.column_names, next(name_gen), next(name_gen)],
  [*df.dtypes, part_id_dtype, part_id_dtype],
+ stream=df.stream,
  )


@@ -159,7 +200,7 @@ def _get_final_sort_boundaries(
  column_order: Sequence[plc.types.Order],
  null_order: Sequence[plc.types.NullOrder],
  num_partitions: int,
- ) -> plc.Table:
+ ) -> DataFrame:
  """
  Find the global sort split boundaries from all gathered split candidates.

@@ -186,22 +227,28 @@ def _get_final_sort_boundaries(
  # split candidates has the additional partition_id and row_number columns
  column_order + [plc.types.Order.ASCENDING] * 2,
  null_order + [plc.types.NullOrder.AFTER] * 2,
+ stream=sort_boundaries_candidates.stream,
  )
  selected_candidates = plc.Column.from_iterable_of_py(
  [
  i * sorted_candidates.num_rows() // num_partitions
  for i in range(1, num_partitions)
- ]
+ ],
+ stream=sort_boundaries_candidates.stream,
  )
  # Get the actual values at which we will split the data
  sort_boundaries = plc.copying.gather(
- sorted_candidates, selected_candidates, plc.copying.OutOfBoundsPolicy.DONT_CHECK
+ sorted_candidates,
+ selected_candidates,
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
+ stream=sort_boundaries_candidates.stream,
  )

  return DataFrame.from_table(
  sort_boundaries,
  sort_boundaries_candidates.column_names,
  sort_boundaries_candidates.dtypes,
+ stream=sort_boundaries_candidates.stream,
  )


@@ -211,6 +258,7 @@ def _sort_boundaries_graph(
  column_order: Sequence[plc.types.Order],
  null_order: Sequence[plc.types.NullOrder],
  count: int,
+ context: IRExecutionContext,
  ) -> tuple[str, MutableMapping[Any, Any]]:
  """Graph to get the boundaries from all partitions."""
  local_boundaries_name = f"sort-boundaries_local-{name_in}"
@@ -229,7 +277,7 @@ def _sort_boundaries_graph(
  )
  _concat_list.append((local_boundaries_name, part_id))

- graph[concat_boundaries_name] = (_concat, *_concat_list)
+ graph[concat_boundaries_name] = (partial(_concat, context=context), *_concat_list)
  graph[global_boundaries_name] = (
  _get_final_sort_boundaries,
  concat_boundaries_name,
@@ -276,6 +324,11 @@ class RMPFIntegrationSortedShuffle: # pragma: no cover
  context = get_worker_context()

  by = options["by"]
+ data_streams = [
+ df.stream,
+ sort_boundaries.stream,
+ ]
+ stream = get_joined_cuda_stream(get_dask_cuda_stream, upstreams=data_streams)

  splits = find_sort_splits(
  df.select(by).table,
@@ -283,15 +336,20 @@ class RMPFIntegrationSortedShuffle: # pragma: no cover
  partition_id,
  options["order"],
  options["null_order"],
+ stream=stream,
  )
  packed_inputs = split_and_pack(
  df.table,
  splits=splits,
  br=context.br,
- stream=DEFAULT_STREAM,
+ stream=stream,
  )
+ # TODO: figure out handoff with rapidsmpf
+ # https://github.com/rapidsai/cudf/issues/20337
  shuffler.insert_chunks(packed_inputs)

+ join_cuda_streams(downstreams=data_streams, upstreams=[stream])
+
  @staticmethod
  def extract_partition(
  partition_id: int,
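
The added lines in the two hunks above follow a fork/join pattern around CUDA streams: a working stream is obtained that is ordered after the streams that produced the inputs, all device work for the step is launched on it, and the input streams are then joined back so that later users of the inputs wait on that work. A condensed sketch of the pattern is below; the helper names and call shapes are copied from the hunks above, their definitions live in the new cudf_polars/utils/cuda_stream.py module (not shown in this excerpt), and do_device_work is a placeholder:

from cudf_polars.utils.cuda_stream import (
    get_dask_cuda_stream,
    get_joined_cuda_stream,
    join_cuda_streams,
)


def shuffle_step(df, sort_boundaries, do_device_work):
    # Fork: obtain a stream that is ordered after every input's stream.
    data_streams = [df.stream, sort_boundaries.stream]
    stream = get_joined_cuda_stream(get_dask_cuda_stream, upstreams=data_streams)

    # Launch all device work for this step on that single stream, e.g.
    # find_sort_splits(..., stream=stream) and split_and_pack(..., stream=stream)
    # in insert_partition above.
    result = do_device_work(df, sort_boundaries, stream=stream)

    # Join: make the input streams wait on the work just enqueued, so later
    # users of df / sort_boundaries remain correctly ordered.
    join_cuda_streams(downstreams=data_streams, upstreams=[stream])
    return result
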
@@ -316,8 +374,12 @@ class RMPFIntegrationSortedShuffle: # pragma: no cover
  column_names = options["column_names"]
  column_dtypes = options["column_dtypes"]

+ stream = DEFAULT_STREAM
+
  # TODO: When sorting, this step should finalize with a merge (unless we
  # require stability, as cudf merge is not stable).
+ # TODO: figure out handoff with rapidsmpf
+ # https://github.com/rapidsai/cudf/issues/20337
  return DataFrame.from_table(
  unpack_and_concat(
  unspill_partitions(
@@ -327,10 +389,11 @@ class RMPFIntegrationSortedShuffle: # pragma: no cover
  statistics=context.statistics,
  ),
  br=context.br,
- stream=DEFAULT_STREAM,
+ stream=stream,
  ),
  column_names,
  column_dtypes,
+ stream=stream,
  )


@@ -359,7 +422,11 @@ def _sort_partition_dataframe(
  """
  if df.num_rows == 0: # pragma: no cover
  # Fast path for empty DataFrame
- return {i: df for i in range(partition_count)}
+ return dict.fromkeys(range(partition_count), df)
+
+ stream = get_joined_cuda_stream(
+ get_dask_cuda_stream, upstreams=(df.stream, sort_boundaries.stream)
+ )

  splits = find_sort_splits(
  df.select(options["by"]).table,
@@ -367,6 +434,7 @@ def _sort_partition_dataframe(
  partition_id,
  options["order"],
  options["null_order"],
+ stream=stream,
  )

  # Split and return the partitioned result
@@ -375,8 +443,9 @@ def _sort_partition_dataframe(
  split,
  df.column_names,
  df.dtypes,
+ stream=df.stream,
  )
- for i, split in enumerate(plc.copying.split(df.table, splits))
+ for i, split in enumerate(plc.copying.split(df.table, splits, stream=stream))
  }


@@ -428,6 +497,8 @@ class ShuffleSorted(IR):
  null_order: tuple[plc.types.NullOrder, ...],
  shuffle_method: ShuffleMethod,
  df: DataFrame,
+ *,
+ context: IRExecutionContext,
  ) -> DataFrame: # pragma: no cover
  """Evaluate and return a dataframe."""
  # Single-partition ShuffleSorted evaluation is a no-op
@@ -532,7 +603,9 @@ def _(

  @generate_ir_tasks.register(ShuffleSorted)
  def _(
- ir: ShuffleSorted, partition_info: MutableMapping[IR, PartitionInfo]
+ ir: ShuffleSorted,
+ partition_info: MutableMapping[IR, PartitionInfo],
+ context: IRExecutionContext,
  ) -> MutableMapping[Any, Any]:
  by = [ne.value.name for ne in ir.by if isinstance(ne.value, Col)]
  if len(by) != len(ir.by): # pragma: no cover
@@ -547,6 +620,7 @@ def _(
  ir.order,
  ir.null_order,
  partition_info[child].count,
+ context,
  )

  options = {
@@ -596,7 +670,7 @@ def _(

  # Simple task-based fall-back
  graph.update(
- _simple_shuffle_graph(
+ partial(_simple_shuffle_graph, context=context)(
  get_key_name(child),
  get_key_name(ir),
  partition_info[child].count,
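
Several hunks above use functools.partial to thread the new IRExecutionContext keyword into task graphs without changing the (callable, *args) task-tuple layout, e.g. (partial(_concat, context=context), *_concat_list). A minimal, generic illustration of that binding pattern (the callable and context below are stand-ins, not cudf-polars types):

from functools import partial


def concat(*dfs, context):
    # Stand-in for a task callable that now requires a keyword-only context.
    return {"parts": dfs, "context": context}


context = {"executor": "streaming"}  # stand-in for an IRExecutionContext

# Dask-style task tuple: (callable, *arguments). Binding the keyword with
# functools.partial keeps the positional tuple shape intact.
graph = {"concat-0": (partial(concat, context=context), "chunk-0", "chunk-1")}

func, *args = graph["concat-0"]
print(func(*args))  # {'parts': ('chunk-0', 'chunk-1'), 'context': {'executor': 'streaming'}}
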
@@ -8,9 +8,9 @@ from typing import TYPE_CHECKING, Any

  from dask.sizeof import sizeof
  from distributed import get_worker
- from rapidsmpf.buffer.buffer import MemoryType
  from rapidsmpf.integrations.dask.core import get_worker_context
  from rapidsmpf.integrations.dask.spilling import SpillableWrapper
+ from rapidsmpf.memory.buffer import MemoryType
  from cudf_polars.containers import DataFrame


@@ -37,6 +37,7 @@ from cudf_polars.experimental.dispatch import (
  from cudf_polars.experimental.expressions import _SUPPORTED_AGGS
  from cudf_polars.experimental.utils import _leaf_column_names
  from cudf_polars.utils import conversion
+ from cudf_polars.utils.cuda_stream import get_cuda_stream

  if TYPE_CHECKING:
  from collections.abc import Mapping, Sequence
@@ -47,7 +48,10 @@ if TYPE_CHECKING:
  from cudf_polars.utils.config import ConfigOptions, StatsPlanningOptions


- def collect_statistics(root: IR, config_options: ConfigOptions) -> StatsCollector:
+ def collect_statistics(
+ root: IR,
+ config_options: ConfigOptions,
+ ) -> StatsCollector:
  """
  Collect column statistics for a query.

@@ -607,7 +611,12 @@ def _(ir: IR, stats: StatsCollector, config_options: ConfigOptions) -> None:


  @update_column_stats.register(DataFrameScan)
- def _(ir: DataFrameScan, stats: StatsCollector, config_options: ConfigOptions) -> None:
+ def _(
+ ir: DataFrameScan,
+ stats: StatsCollector,
+ config_options: ConfigOptions,
+ ) -> None:
+ stream = get_cuda_stream()
  # Use datasource row-count estimate.
  if stats.column_stats[ir]:
  stats.row_count[ir] = next(
@@ -620,15 +629,23 @@ def _(ir: DataFrameScan, stats: StatsCollector, config_options: ConfigOptions) -
  for column_stats in stats.column_stats[ir].values():
  if column_stats.source_info.implied_unique_count.value is None:
  # We don't have a unique-count estimate, so we need to sample the data.
- source_unique_stats = column_stats.source_info.unique_stats(force=False)
+ source_unique_stats = column_stats.source_info.unique_stats(
+ force=False,
+ )
  if source_unique_stats.count.value is not None:
  column_stats.unique_count = source_unique_stats.count
  else:
  column_stats.unique_count = column_stats.source_info.implied_unique_count

+ stream.synchronize()
+

  @update_column_stats.register(Scan)
- def _(ir: Scan, stats: StatsCollector, config_options: ConfigOptions) -> None:
+ def _(
+ ir: Scan,
+ stats: StatsCollector,
+ config_options: ConfigOptions,
+ ) -> None:
  # Use datasource row-count estimate.
  if stats.column_stats[ir]:
  stats.row_count[ir] = next(
@@ -649,7 +666,9 @@ def _(ir: Scan, stats: StatsCollector, config_options: ConfigOptions) -> None:
  for column_stats in stats.column_stats[ir].values():
  if column_stats.source_info.implied_unique_count.value is None:
  # We don't have a unique-count estimate, so we need to sample the data.
- source_unique_stats = column_stats.source_info.unique_stats(force=False)
+ source_unique_stats = column_stats.source_info.unique_stats(
+ force=False,
+ )
  if source_unique_stats.count.value is not None:
  column_stats.unique_count = source_unique_stats.count
  elif (
@@ -20,15 +20,15 @@ if TYPE_CHECKING:

  from cudf_polars.containers import DataFrame
  from cudf_polars.dsl.expr import Expr
- from cudf_polars.dsl.ir import IR
+ from cudf_polars.dsl.ir import IR, IRExecutionContext
  from cudf_polars.experimental.base import ColumnStats
  from cudf_polars.experimental.dispatch import LowerIRTransformer
  from cudf_polars.utils.config import ConfigOptions


- def _concat(*dfs: DataFrame) -> DataFrame:
+ def _concat(*dfs: DataFrame, context: IRExecutionContext) -> DataFrame:
  # Concatenate a sequence of DataFrames vertically
- return Union.do_evaluate(None, *dfs)
+ return dfs[0] if len(dfs) == 1 else Union.do_evaluate(None, *dfs, context=context)


  def _fallback_inform(msg: str, config_options: ConfigOptions) -> None:
@@ -63,23 +63,41 @@ def _lower_ir_fallback(
  # those children will be collapsed with `Repartition`.
  from cudf_polars.experimental.repartition import Repartition

+ # TODO: (IMPORTANT) Since Repartition is a local operation,
+ # the current fallback logic will only work for one rank!
+ # For multiple ranks, we will need to AllGather the data
+ # on all ranks.
+ config_options = rec.state["config_options"]
+ assert config_options.executor.name == "streaming", (
+ "'in-memory' executor not supported in 'generate_ir_sub_network'"
+ )
+ if (
+ (rapidsmpf_engine := config_options.executor.runtime == "rapidsmpf")
+ and config_options.executor.scheduler == "distributed"
+ ): # pragma: no cover; Requires distributed
+ raise NotImplementedError(
+ "Fallback is not yet supported distributed execution "
+ "with the RAPIDS-MPF streaming runtime."
+ )
+
  # Lower children
  lowered_children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
  partition_info = reduce(operator.or_, _partition_info)

  # Ensure all children are single-partitioned
  children = []
- fallback = False
+ inform = False
  for c in lowered_children:
  child = c
- if partition_info[c].count > 1:
+ if multi_partitioned := partition_info[c].count > 1:
+ inform = True
+ if multi_partitioned or rapidsmpf_engine:
  # Fall-back logic
- fallback = True
  child = Repartition(child.schema, child)
  partition_info[child] = PartitionInfo(count=1)
  children.append(child)

- if fallback and msg:
+ if inform and msg:
  # Warn/raise the user if any children were collapsed
  # and the "fallback_mode" configuration is not "silent"
  _fallback_inform(msg, rec.state["config_options"])
@@ -28,9 +28,10 @@ __all__: list[str] = [
  ]

  # Will be overriden by `conftest.py` with the value from the `--executor`
- # and `--scheduler` command-line arguments
+ # and `--cluster` command-line arguments
  DEFAULT_EXECUTOR = "in-memory"
- DEFAULT_SCHEDULER = "synchronous"
+ DEFAULT_RUNTIME = "tasks"
+ DEFAULT_CLUSTER = "single"
  DEFAULT_BLOCKSIZE_MODE: Literal["small", "default"] = "default"


@@ -111,8 +112,8 @@ def assert_gpu_result_equal(

  # These keywords are correct, but mypy doesn't see that.
  # the 'misc' is for 'error: Keywords must be strings'
- expect = lazydf.collect(**final_polars_collect_kwargs) # type: ignore[call-overload,misc]
- got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[call-overload,misc]
+ expect = lazydf.collect(**final_polars_collect_kwargs) # type: ignore[misc, call-overload]
+ got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[misc, call-overload]

  assert_kwargs_bool: dict[str, bool] = {
  "check_row_order": check_row_order,
@@ -128,11 +129,14 @@ def assert_gpu_result_equal(
  else:
  tol_kwargs = {"rel_tol": rtol, "abs_tol": atol}

+ # the type checker errors with:
+ # Argument 4 to "assert_frame_equal" has incompatible type "**dict[str, float]"; expected "bool" [arg-type]
+ # which seems to be a bug in the type checker / type annotations.
  assert_frame_equal(
  expect,
  got,
  **assert_kwargs_bool,
- **tol_kwargs,
+ **tol_kwargs, # type: ignore[arg-type]
  )


@@ -202,7 +206,8 @@ def get_default_engine(
  executor_options: dict[str, Any] = {}
  executor = executor or DEFAULT_EXECUTOR
  if executor == "streaming":
- executor_options["scheduler"] = DEFAULT_SCHEDULER
+ executor_options["cluster"] = DEFAULT_CLUSTER
+ executor_options["runtime"] = DEFAULT_RUNTIME

  blocksize_mode = blocksize_mode or DEFAULT_BLOCKSIZE_MODE
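
The get_default_engine hunk above replaces the single scheduler default with separate cluster and runtime options for the streaming executor. A hedged sketch of how such a configuration might be assembled: the option names and default values come from this diff, while routing them through pl.GPUEngine's executor/executor_options keywords is an assumption here rather than code shown in this excerpt:

import polars as pl

# Defaults taken from the hunk above (DEFAULT_RUNTIME / DEFAULT_CLUSTER replace
# the old DEFAULT_SCHEDULER = "synchronous").
executor = "streaming"
executor_options = {"cluster": "single", "runtime": "tasks"}

# Assumption: GPUEngine accepts executor/executor_options keywords, as the
# cudf-polars engine configuration does; raise_on_fail also appears in
# assert_collect_raises below.
engine = pl.GPUEngine(
    executor=executor,
    executor_options=executor_options,
    raise_on_fail=True,
)

out = pl.LazyFrame({"a": [1, 2, 3]}).select(pl.col("a").sum()).collect(engine=engine)
print(out)
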
@@ -289,7 +294,7 @@ def assert_collect_raises(
  )

  try:
- lazydf.collect(**final_polars_collect_kwargs) # type: ignore[call-overload,misc]
+ lazydf.collect(**final_polars_collect_kwargs) # type: ignore[misc, call-overload]
  except polars_except:
  pass
  except Exception as e:
@@ -302,7 +307,7 @@ def assert_collect_raises(

  engine = GPUEngine(raise_on_fail=True)
  try:
- lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[call-overload,misc]
+ lazydf.collect(**final_cudf_collect_kwargs, engine=engine) # type: ignore[misc, call-overload]
  except cudf_except:
  pass
  except Exception as e:
cudf_polars/testing/io.py CHANGED
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING
  import polars as pl

  if TYPE_CHECKING:
+ from collections.abc import Callable
  from typing import Literal

  __all__: list[str] = ["make_partitioned_source"]
@@ -110,7 +111,7 @@ def make_lazy_frame(
  assert path is not None, f"path is required for fmt={fmt}."
  row_group_size: int | None = None
  if fmt == "parquet":
- read = pl.scan_parquet
+ read: Callable[..., pl.LazyFrame] = pl.scan_parquet
  row_group_size = 10
  elif fmt == "csv":
  read = pl.scan_csv