cudf-polars-cu13 25.10.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/experimental/sort.py
@@ -0,0 +1,609 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Sorting Logic."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal, TypedDict

import polars as pl

import pylibcudf as plc
from rmm.pylibrmm.stream import DEFAULT_STREAM

from cudf_polars.containers import Column, DataFrame, DataType
from cudf_polars.dsl.expr import Col
from cudf_polars.dsl.ir import IR, Sort
from cudf_polars.dsl.traversal import traversal
from cudf_polars.dsl.utils.naming import unique_names
from cudf_polars.experimental.base import PartitionInfo, get_key_name
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
from cudf_polars.experimental.repartition import Repartition
from cudf_polars.experimental.shuffle import _simple_shuffle_graph
from cudf_polars.experimental.utils import _concat, _fallback_inform, _lower_ir_fallback
from cudf_polars.utils.config import ShuffleMethod

if TYPE_CHECKING:
    from collections.abc import MutableMapping, Sequence

    from cudf_polars.dsl.expr import NamedExpr
    from cudf_polars.experimental.dispatch import LowerIRTransformer
    from cudf_polars.typing import Schema

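# A multi-partition sort proceeds in three phases:
#   1. each locally sorted partition contributes evenly spaced boundary
#      candidates (_select_local_split_candidates),
#   2. the candidates are gathered and reduced to num_partitions - 1 global
#      boundaries (_get_final_sort_boundaries),
#   3. each partition splits its sorted rows at those boundaries
#      (find_sort_splits) and the pieces are shuffled to their destination
#      partitions, which sort once more to merge the received runs.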


def find_sort_splits(
    tbl: plc.Table,
    sort_boundaries: plc.Table,
    my_part_id: int,
    column_order: Sequence[plc.types.Order],
    null_order: Sequence[plc.types.NullOrder],
) -> list[int]:
    """
    Find local sort splits given all (global) split candidates.

    The reason for much of the complexity is to get the result sizes as
    precise as possible even when e.g. all values are equal.
    In other words, this goes through extra effort to split the data at the
    precise boundaries (which includes part_id and local_row_number).

    Parameters
    ----------
    tbl
        Locally sorted table only containing sort columns.
    sort_boundaries
        Sorted table containing the global sort/split boundaries. Compared to
        `tbl`, it must contain additional partition_id and local_row_number columns.
    my_part_id
        The partition id of the local node (as the `split_candidates` column).
    column_order
        The order in which tbl is sorted.
    null_order
        The null order in which tbl is sorted.

    Returns
    -------
    The split points for the local partition.
    """
    column_order = list(column_order)
    null_order = list(null_order)

    # We now need to find the local split points. To do this, first split out
    # the partition id and the local row number of the final split values
    *sort_boundaries, split_part_id, split_local_row = sort_boundaries.columns()
    sort_boundaries = plc.Table(sort_boundaries)
    # Now we find the first and last row in the local table corresponding to the split value
    # (first and last, because there may be multiple rows with the same split value)
    split_first_col = plc.search.lower_bound(
        tbl, sort_boundaries, column_order, null_order
    )
    split_last_col = plc.search.upper_bound(
        tbl, sort_boundaries, column_order, null_order
    )
    # And convert to list for final processing
    split_first_col = pl.Series(split_first_col).to_list()
    split_last_col = pl.Series(split_last_col).to_list()
    split_part_id = pl.Series(split_part_id).to_list()
    split_local_row = pl.Series(split_local_row).to_list()

    # Find the final split points. This is slightly tricky because of the possibility
    # of equal values, which is why we need the part_id and local_row.
    # Consider for example the case when all data is equal.
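    # (If every key value is identical, lower_bound is 0 and upper_bound is
    # len(tbl) for every boundary, so the key values alone cannot place the
    # split; the part_id/local_row tie-breakers below resolve it.)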
    split_points = []
    for first, last, part_id, local_row in zip(
        split_first_col, split_last_col, split_part_id, split_local_row, strict=False
    ):
        if part_id < my_part_id:
            # Local data is globally later so split at first valid row.
            split_points.append(first)
        elif part_id > my_part_id:
            # Local data is globally earlier so split after last valid row.
            split_points.append(last)
        else:
            # The split point is within our chunk, so use original local row
            split_points.append(local_row)

    return split_points


def _select_local_split_candidates(
    df: DataFrame,
    by: Sequence[str],
    num_partitions: int,
    my_part_id: int,
) -> DataFrame:
    """
    Select the local sort boundary candidates for one partition.

    Returns a DataFrame with the local sort boundaries (including part and
    row id columns). The columns are already in the order of `by`.
    """
    df = df.select(by)
    name_gen = unique_names(df.column_names)
    part_id_dtype = DataType(pl.UInt32())
    if df.num_rows == 0:
        # Return empty DataFrame with the correct column names and dtypes
        return DataFrame(
            [
                *df.columns,
                Column(
                    plc.column_factories.make_empty_column(part_id_dtype.plc),
                    dtype=part_id_dtype,
                    name=next(name_gen),
                ),
                Column(
                    plc.column_factories.make_empty_column(part_id_dtype.plc),
                    dtype=part_id_dtype,
                    name=next(name_gen),
                ),
            ]
        )

    candidates = [i * df.num_rows // num_partitions for i in range(num_partitions)]
    row_id = plc.Column.from_iterable_of_py(candidates, part_id_dtype.plc)

    res = plc.copying.gather(df.table, row_id, plc.copying.OutOfBoundsPolicy.DONT_CHECK)
    part_id = plc.Column.from_scalar(
        plc.Scalar.from_py(my_part_id, part_id_dtype.plc),
        len(candidates),
    )

    return DataFrame.from_table(
        plc.Table([*res.columns(), part_id, row_id]),
        [*df.column_names, next(name_gen), next(name_gen)],
        [*df.dtypes, part_id_dtype, part_id_dtype],
    )


def _get_final_sort_boundaries(
    sort_boundaries_candidates: DataFrame,
    column_order: Sequence[plc.types.Order],
    null_order: Sequence[plc.types.NullOrder],
    num_partitions: int,
) -> DataFrame:
    """
    Find the global sort split boundaries from all gathered split candidates.

    Parameters
    ----------
    sort_boundaries_candidates
        All gathered split candidates.
    column_order
        The order in which the split candidates are sorted.
    null_order
        The null order in which the split candidates are sorted.
    num_partitions
        The number of partitions to split the data into.

    """
    column_order = list(column_order)
    null_order = list(null_order)

    # The global split candidates need to be stable sorted to find the correct
    # final split points.
    # NOTE: This could be a merge if done earlier (but it should be small data).
    sorted_candidates = plc.sorting.sort(
        sort_boundaries_candidates.table,
        # split candidates have the additional partition_id and row_number columns
        column_order + [plc.types.Order.ASCENDING] * 2,
        null_order + [plc.types.NullOrder.AFTER] * 2,
    )
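    # Pick num_partitions - 1 evenly spaced rows (indices i * N // num_partitions
    # for i in 1..num_partitions-1) from the globally sorted candidates; these
    # boundaries divide the sorted data into num_partitions output ranges.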
    selected_candidates = plc.Column.from_iterable_of_py(
        [
            i * sorted_candidates.num_rows() // num_partitions
            for i in range(1, num_partitions)
        ]
    )
    # Get the actual values at which we will split the data
    sort_boundaries = plc.copying.gather(
        sorted_candidates, selected_candidates, plc.copying.OutOfBoundsPolicy.DONT_CHECK
    )

    return DataFrame.from_table(
        sort_boundaries,
        sort_boundaries_candidates.column_names,
        sort_boundaries_candidates.dtypes,
    )


def _sort_boundaries_graph(
    name_in: str,
    by: Sequence[str],
    column_order: Sequence[plc.types.Order],
    null_order: Sequence[plc.types.NullOrder],
    count: int,
) -> tuple[str, MutableMapping[Any, Any]]:
    """Graph to get the boundaries from all partitions."""
    local_boundaries_name = f"sort-boundaries_local-{name_in}"
    concat_boundaries_name = f"sort-boundaries-concat-{name_in}"
    global_boundaries_name = f"sort-boundaries-{name_in}"
    graph: MutableMapping[Any, Any] = {}

    _concat_list = []
    for part_id in range(count):
        graph[(local_boundaries_name, part_id)] = (
            _select_local_split_candidates,
            (name_in, part_id),
            by,
            count,
            part_id,
        )
        _concat_list.append((local_boundaries_name, part_id))

    graph[concat_boundaries_name] = (_concat, *_concat_list)
    graph[global_boundaries_name] = (
        _get_final_sort_boundaries,
        concat_boundaries_name,
        column_order,
        null_order,
        count,
    )
    return global_boundaries_name, graph


class SortedShuffleOptions(TypedDict):
    """RapidsMPF shuffling options."""

    by: Sequence[str]
    order: Sequence[plc.types.Order]
    null_order: Sequence[plc.types.NullOrder]
    column_names: Sequence[str]
    column_dtypes: Sequence[DataType]
    cluster_kind: Literal["dask", "single"]


# Experimental rapidsmpf shuffler integration
class RMPFIntegrationSortedShuffle:  # pragma: no cover
    """cuDF-Polars protocol for rapidsmpf shuffler."""

    @staticmethod
    def insert_partition(
        df: DataFrame,
        partition_id: int,
        partition_count: int,
        shuffler: Any,
        options: SortedShuffleOptions,
        sort_boundaries: DataFrame,
    ) -> None:
        """Add cudf-polars DataFrame chunks to an RMP shuffler."""
        from rapidsmpf.integrations.cudf.partition import split_and_pack

        if options["cluster_kind"] == "dask":
            from rapidsmpf.integrations.dask import get_worker_context

        else:
            from rapidsmpf.integrations.single import get_worker_context

        context = get_worker_context()

        by = options["by"]

        splits = find_sort_splits(
            df.select(by).table,
            sort_boundaries.table,
            partition_id,
            options["order"],
            options["null_order"],
        )
        packed_inputs = split_and_pack(
            df.table,
            splits=splits,
            br=context.br,
            stream=DEFAULT_STREAM,
        )
        shuffler.insert_chunks(packed_inputs)

    @staticmethod
    def extract_partition(
        partition_id: int,
        shuffler: Any,
        options: SortedShuffleOptions,
    ) -> DataFrame:
        """Extract a finished partition from the RMP shuffler."""
        from rapidsmpf.integrations.cudf.partition import (
            unpack_and_concat,
            unspill_partitions,
        )

        if options["cluster_kind"] == "dask":
            from rapidsmpf.integrations.dask import get_worker_context

        else:
            from rapidsmpf.integrations.single import get_worker_context

        context = get_worker_context()

        shuffler.wait_on(partition_id)
        column_names = options["column_names"]
        column_dtypes = options["column_dtypes"]

        # TODO: When sorting, this step should finalize with a merge (unless we
        # require stability, as cudf merge is not stable).
        return DataFrame.from_table(
            unpack_and_concat(
                unspill_partitions(
                    shuffler.extract(partition_id),
                    br=context.br,
                    allow_overbooking=True,
                    statistics=context.statistics,
                ),
                br=context.br,
                stream=DEFAULT_STREAM,
            ),
            column_names,
            column_dtypes,
        )


def _sort_partition_dataframe(
    df: DataFrame,
    partition_id: int,
    partition_count: int,
    options: MutableMapping[str, Any],
    sort_boundaries: DataFrame,
) -> MutableMapping[int, DataFrame]:
    """
    Partition a sorted DataFrame for shuffling.

    Parameters
    ----------
    df
        The DataFrame to partition.
    partition_id
        The partition id of the current partition.
    partition_count
        The total number of partitions.
    options
        The sort options ``(by, order, null_order)``.
    sort_boundaries
        The global sort boundary candidates used to decide where to split.
    """
    if df.num_rows == 0:  # pragma: no cover
        # Fast path for empty DataFrame
        return {i: df for i in range(partition_count)}

    splits = find_sort_splits(
        df.select(options["by"]).table,
        sort_boundaries.table,
        partition_id,
        options["order"],
        options["null_order"],
    )

    # Split and return the partitioned result
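    # (keyed by destination partition id, so the task-based shuffle can route
    # piece ``i`` of every input chunk to output partition ``i``)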
    return {
        i: DataFrame.from_table(
            split,
            df.column_names,
            df.dtypes,
        )
        for i, split in enumerate(plc.copying.split(df.table, splits))
    }


class ShuffleSorted(IR):
    """
    Shuffle already locally sorted multi-partition data.

    Shuffling is performed by extracting sort boundary candidates from all partitions,
    sharing them all-to-all and then exchanging data accordingly.
    The sort keys and ordering must be passed in exactly as used for the local sort
    that was already performed, and for now the final result needs to be sorted
    again to merge the partitions.
    """

    __slots__ = ("by", "null_order", "order", "shuffle_method")
    _non_child = ("schema", "by", "order", "null_order", "shuffle_method")
    by: tuple[NamedExpr, ...]
    """Keys by which the data was sorted."""
    order: tuple[plc.types.Order, ...]
    """Sort order if sorted."""
    null_order: tuple[plc.types.NullOrder, ...]
    """Null precedence if sorted."""
    shuffle_method: ShuffleMethod
    """Shuffle method to use."""

    def __init__(
        self,
        schema: Schema,
        by: tuple[NamedExpr, ...],
        order: tuple[plc.types.Order, ...],
        null_order: tuple[plc.types.NullOrder, ...],
        shuffle_method: ShuffleMethod,
        df: IR,
    ):
        self.schema = schema
        self.by = by
        self.order = order
        self.null_order = null_order
        self.shuffle_method = shuffle_method
        self._non_child_args = (schema, by, order, null_order, shuffle_method)
        self.children = (df,)

    @classmethod
    def do_evaluate(
        cls,
        schema: Schema,
        by: tuple[NamedExpr, ...],
        order: tuple[plc.types.Order, ...],
        null_order: tuple[plc.types.NullOrder, ...],
        shuffle_method: ShuffleMethod,
        df: DataFrame,
    ) -> DataFrame:  # pragma: no cover
        """Evaluate and return a dataframe."""
        # Single-partition ShuffleSorted evaluation is a no-op
        return df


@lower_ir_node.register(Sort)
def _(
    ir: Sort, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
    # Special handling for slicing
    # (May be a top- or bottom-k operation)

    if ir.zlice is not None:
        # TODO: Handle large slices (e.g. 1m+ rows), this should go into the branch
        # below, but will require additional logic there.

        # Check whether zlice has an offset, i.e. it does not include the start
        # (positive offset) or does not reach the end (negative offset). If an
        # offset exists it would be incorrect to apply the slice in the first
        # partition-wise (pwise) sort.
        has_offset = ir.zlice[0] > 0 or (
            ir.zlice[0] < 0
            and ir.zlice[1] is not None
            and ir.zlice[0] + ir.zlice[1] < 0
        )
        if has_offset:
            return _lower_ir_fallback(
                ir,
                rec,
                msg="Sort does not support a multi-partition slice with an offset.",
            )

        from cudf_polars.experimental.parallel import _lower_ir_pwise

        # Sort input partitions
        new_node, partition_info = _lower_ir_pwise(ir, rec)
        if partition_info[new_node].count > 1:
            # Collapse down to single partition
            inter = Repartition(new_node.schema, new_node)
            partition_info[inter] = PartitionInfo(count=1)
            # Sort reduced partition
            new_node = ir.reconstruct([inter])
            partition_info[new_node] = PartitionInfo(count=1)
        return new_node, partition_info

    # Check sort keys
    if not all(
        isinstance(expr, Col) for expr in traversal([e.value for e in ir.by])
    ):  # pragma: no cover
        return _lower_ir_fallback(
            ir,
            rec,
            msg="sort currently only supports column names as `by` keys.",
        )

    # Extract child partitioning
    child, partition_info = rec(ir.children[0])

    # Extract shuffle method
    config_options = rec.state["config_options"]
    assert config_options.executor.name == "streaming", (
        "'in-memory' executor not supported in 'lower_ir_node'"
    )
    # Avoid rapidsmpf shuffle with maintain_order=True (for now)
    shuffle_method = (
        ShuffleMethod("tasks") if ir.stable else config_options.executor.shuffle_method
    )
    if (
        shuffle_method != config_options.executor.shuffle_method
    ):  # pragma: no cover; Requires rapidsmpf
        _fallback_inform(
            f"shuffle_method={shuffle_method} does not support maintain_order=True. "
            "Falling back to shuffle_method='tasks'.",
            config_options,
        )

    # Handle single-partition case
    if partition_info[child].count == 1:
        single_part_node = ir.reconstruct([child])
        partition_info[single_part_node] = partition_info[child]
        return single_part_node, partition_info

    local_sort_node = ir.reconstruct([child])
    partition_info[local_sort_node] = partition_info[child]

    shuffle = ShuffleSorted(
        ir.schema,
        ir.by,
        ir.order,
        ir.null_order,
        shuffle_method,
        local_sort_node,
    )
    partition_info[shuffle] = partition_info[child]

    # We sort again locally.
    assert ir.zlice is None  # zlice handling would be incorrect without adjustment
    final_sort_node = ir.reconstruct([shuffle])
    partition_info[final_sort_node] = partition_info[shuffle]

    return final_sort_node, partition_info
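

# The lowering above turns a multi-partition ``Sort`` into local Sort ->
# ShuffleSorted -> final Sort; the task generation below builds the boundary
# exchange and the actual shuffle for the ShuffleSorted stage.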
@generate_ir_tasks.register(ShuffleSorted)
def _(
    ir: ShuffleSorted, partition_info: MutableMapping[IR, PartitionInfo]
) -> MutableMapping[Any, Any]:
    by = [ne.value.name for ne in ir.by if isinstance(ne.value, Col)]
    if len(by) != len(ir.by):  # pragma: no cover
        # We should not reach here as this is checked in the lower_ir_node
        raise NotImplementedError("Sorting columns must be column names.")

    (child,) = ir.children

    sort_boundaries_name, graph = _sort_boundaries_graph(
        get_key_name(child),
        by,
        ir.order,
        ir.null_order,
        partition_info[child].count,
    )

    options = {
        "by": by,
        "order": ir.order,
        "null_order": ir.null_order,
        "column_names": list(ir.schema.keys()),
        "column_dtypes": list(ir.schema.values()),
    }

    # Try using rapidsmpf shuffler if we have "simple" shuffle
    # keys, and the "shuffle_method" config is set to "rapidsmpf"
    shuffle_method = ir.shuffle_method
    if shuffle_method in ("rapidsmpf", "rapidsmpf-single"):  # pragma: no cover
        try:
            if shuffle_method == "rapidsmpf-single":
                from rapidsmpf.integrations.single import rapidsmpf_shuffle_graph

                options["cluster_kind"] = "single"
            else:
                from rapidsmpf.integrations.dask import rapidsmpf_shuffle_graph

                options["cluster_kind"] = "dask"
            graph.update(
                rapidsmpf_shuffle_graph(
                    get_key_name(child),
                    get_key_name(ir),
                    partition_info[child].count,
                    partition_info[ir].count,
                    RMPFIntegrationSortedShuffle,
                    options,
                    sort_boundaries_name,
                )
            )
        except (ImportError, ValueError) as err:
            # ImportError: rapidsmpf is not installed
            # ValueError: rapidsmpf couldn't find a distributed client
            if shuffle_method == "rapidsmpf":  # pragma: no cover
                # Only raise an error if the user specifically
                # set the shuffle method to "rapidsmpf"
                raise ValueError(
                    "Rapidsmpf is not installed correctly or the current "
                    "Dask cluster does not support rapidsmpf shuffling."
                ) from err
        else:
            return graph

    # Simple task-based fall-back
    graph.update(
        _simple_shuffle_graph(
            get_key_name(child),
            get_key_name(ir),
            partition_info[child].count,
            partition_info[ir].count,
            _sort_partition_dataframe,
            options,
            sort_boundaries_name,
        )
    )
    return graph
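
For orientation, the path above runs whenever a multi-partition sort executes on the
streaming executor. A minimal usage sketch follows; it assumes the Polars GPUEngine
accepts the "executor" and "executor_options" arguments shown, and the option names
"max_rows_per_partition" and "shuffle_method" are assumptions that may differ between
releases:

import polars as pl

# Hypothetical executor configuration; "max_rows_per_partition" caps the rows
# per partition so large inputs become multiple partitions and exercise the
# boundary-exchange sort path above.
engine = pl.GPUEngine(
    executor="streaming",
    executor_options={
        "max_rows_per_partition": 1_000_000,
        "shuffle_method": "tasks",  # or "rapidsmpf" on a Dask cluster
    },
)

# Hypothetical input; sorting by a plain column name keeps the multi-partition
# path eligible, while expression keys fall back to a single-partition sort
# (see lower_ir_node above).
lf = pl.scan_parquet("data/*.parquet")
result = lf.sort("key").collect(engine=engine)
print(result)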