cudf-polars-cu13 25.10.0 (py3-none-any.whl)

This diff shows the content of a publicly available package version released to a supported registry. It is provided for informational purposes only and reflects the package as it appears in its public registry.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
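
Two of the package's source files are reproduced in full below: cudf_polars/experimental/select.py (+188) and cudf_polars/experimental/shuffle.py (+354). For orientation, cudf-polars is the backend Polars dispatches to when a query is collected with a GPUEngine, and its "streaming" executor is what drives the multi-partition lowering logic under cudf_polars/experimental/. A minimal usage sketch, assuming a CUDA 13 environment with this wheel and a compatible Polars installed (the file path is hypothetical):

import polars as pl

# Any lazy query; scan_parquet gives the streaming executor a chance to
# split the input into multiple partitions.
q = pl.scan_parquet("data.parquet").group_by("key").agg(pl.col("x").sum())

# Execute on the GPU via cudf-polars. executor="streaming" enables the
# multi-partition machinery shown in the diffs below.
result = q.collect(engine=pl.GPUEngine(executor="streaming"))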
cudf_polars/experimental/select.py
@@ -0,0 +1,188 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+ """Parallel Select Logic."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import polars as pl
+
+ from cudf_polars.dsl import expr
+ from cudf_polars.dsl.expr import Col, Len
+ from cudf_polars.dsl.ir import Empty, HConcat, Scan, Select, Union
+ from cudf_polars.dsl.traversal import traversal
+ from cudf_polars.experimental.base import ColumnStat, PartitionInfo
+ from cudf_polars.experimental.dispatch import lower_ir_node
+ from cudf_polars.experimental.expressions import decompose_expr_graph
+ from cudf_polars.experimental.utils import (
+     _contains_unsupported_fill_strategy,
+     _lower_ir_fallback,
+ )
+
+ if TYPE_CHECKING:
+     from collections.abc import MutableMapping
+
+     from cudf_polars.dsl.ir import IR
+     from cudf_polars.experimental.parallel import LowerIRTransformer
+     from cudf_polars.experimental.statistics import StatsCollector
+     from cudf_polars.utils.config import ConfigOptions
+
+
+ def decompose_select(
+     select_ir: Select,
+     input_ir: IR,
+     partition_info: MutableMapping[IR, PartitionInfo],
+     config_options: ConfigOptions,
+     stats: StatsCollector,
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+     """
+     Decompose a multi-partition Select operation.
+
+     Parameters
+     ----------
+     select_ir
+         The original Select operation to decompose.
+         This object has not been reconstructed with
+         ``input_ir`` as its child yet.
+     input_ir
+         The lowered child of ``select_ir``. This object
+         will be decomposed into a "partial" selection
+         for each element of ``select_ir.exprs``.
+     partition_info
+         A mapping from all unique IR nodes to the
+         associated partitioning information.
+     config_options
+         GPUEngine configuration options.
+     stats
+         Statistics collector.
+
+     Returns
+     -------
+     new_ir, partition_info
+         The rewritten Select node, and a mapping from
+         unique nodes in the new graph to associated
+         partitioning information.
+
+     Notes
+     -----
+     This function uses ``decompose_expr_graph`` to further
+     decompose each element of ``select_ir.exprs``.
+
+     See Also
+     --------
+     decompose_expr_graph
+     """
+     # Collect partial selections
+     selections = []
+     for ne in select_ir.exprs:
+         # Decompose this partial expression
+         new_ne, partial_input_ir, _partition_info = decompose_expr_graph(
+             ne,
+             input_ir,
+             partition_info,
+             config_options,
+             stats.row_count.get(select_ir.children[0], ColumnStat[int](None)),
+             stats.column_stats.get(select_ir.children[0], {}),
+         )
+         pi = _partition_info[partial_input_ir]
+         partial_input_ir = Select(
+             {ne.name: ne.value.dtype},
+             [new_ne],
+             True,  # noqa: FBT003
+             partial_input_ir,
+         )
+         _partition_info[partial_input_ir] = pi
+         partition_info.update(_partition_info)
+         selections.append(partial_input_ir)
+
+     # Concatenate partial selections
+     new_ir: HConcat | Select
+     if len(selections) > 1:
+         new_ir = HConcat(
+             select_ir.schema,
+             True,  # noqa: FBT003
+             *selections,
+         )
+         partition_info[new_ir] = PartitionInfo(
+             count=max(partition_info[c].count for c in selections)
+         )
+     else:
+         new_ir = selections[0]
+
+     return new_ir, partition_info
+
+
+ @lower_ir_node.register(Select)
+ def _(
+     ir: Select, rec: LowerIRTransformer
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+     child, partition_info = rec(ir.children[0])
+     pi = partition_info[child]
+     if pi.count > 1 and _contains_unsupported_fill_strategy(
+         [e.value for e in ir.exprs]
+     ):
+         return _lower_ir_fallback(
+             ir.reconstruct([child]),
+             rec,
+             msg=(
+                 "fill_null with strategy other than 'zero' or 'one' is not supported "
+                 "for multiple partitions; falling back to in-memory evaluation."
+             ),
+         )
+     if (
+         pi.count == 1
+         and Select._is_len_expr(ir.exprs)
+         and isinstance(child, Union)
+         and len(child.children) == 1
+         and isinstance(child.children[0], Scan)
+         and child.children[0].predicate is None
+     ):
+         # Special Case: Fast count.
+         scan = child.children[0]
+         count = scan.fast_count()
+         dtype = ir.exprs[0].value.dtype
+
+         lit_expr = expr.LiteralColumn(
+             dtype, pl.Series(values=[count], dtype=dtype.polars)
+         )
+         named_expr = expr.NamedExpr(ir.exprs[0].name or "len", lit_expr)
+
+         new_node = Select(
+             {named_expr.name: named_expr.value.dtype},
+             [named_expr],
+             should_broadcast=True,
+             df=child,
+         )
+         partition_info[new_node] = PartitionInfo(count=1)
+         return new_node, partition_info
+
+     if not any(
+         isinstance(expr, (Col, Len)) for expr in traversal([e.value for e in ir.exprs])
+     ):
+         # Special Case: Selection does not depend on any columns.
+         new_node = ir.reconstruct([input_ir := Empty({})])
+         partition_info[input_ir] = partition_info[new_node] = PartitionInfo(count=1)
+         return new_node, partition_info
+
+     if pi.count > 1 and not all(
+         expr.is_pointwise for expr in traversal([e.value for e in ir.exprs])
+     ):
+         # Special Case: Multiple partitions with 1+ non-pointwise expressions.
+         try:
+             # Try decomposing the underlying expressions
+             return decompose_select(
+                 ir,
+                 child,
+                 partition_info,
+                 rec.state["config_options"],
+                 rec.state["stats"],
+             )
+         except NotImplementedError:
+             return _lower_ir_fallback(
+                 ir, rec, msg="This selection is not supported for multiple partitions."
+             )
+
+     new_node = ir.reconstruct([child])
+     partition_info[new_node] = pi
+     return new_node, partition_info
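
To make the lowering rules above concrete: a Select that mixes pointwise and non-pointwise expressions over a multi-partition input is routed through decompose_select, which wraps each expression in its own partial Select and recombines the partial results with HConcat. A hedged sketch of such a query (the data source and engine options are illustrative, not taken from this package's tests):

import polars as pl

# Hypothetical multi-partition input; with a single partition the
# handler above simply reconstructs the Select unchanged.
lf = pl.scan_parquet("large_dataset.parquet")

q = lf.select(
    pl.col("a").sum().alias("a_sum"),  # non-pointwise: split out via decompose_expr_graph
    (pl.col("b") * 2).alias("b2"),     # pointwise: evaluated independently per partition
)

result = q.collect(engine=pl.GPUEngine(executor="streaming"))

By contrast, a bare pl.len() over a predicate-free scan takes the fast-count branch above, which replaces the whole selection with a precomputed literal column.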
cudf_polars/experimental/shuffle.py
@@ -0,0 +1,354 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+ # SPDX-License-Identifier: Apache-2.0
+ """Shuffle Logic."""
+
+ from __future__ import annotations
+
+ import operator
+ from typing import TYPE_CHECKING, Any, Concatenate, Literal, TypeVar, TypedDict
+
+ import pylibcudf as plc
+ from rmm.pylibrmm.stream import DEFAULT_STREAM
+
+ from cudf_polars.containers import DataFrame
+ from cudf_polars.dsl.expr import Col
+ from cudf_polars.dsl.ir import IR
+ from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
+ from cudf_polars.experimental.base import get_key_name
+ from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
+ from cudf_polars.experimental.utils import _concat
+
+ if TYPE_CHECKING:
+     from collections.abc import Callable, MutableMapping, Sequence
+
+     from cudf_polars.containers import DataType
+     from cudf_polars.dsl.expr import NamedExpr
+     from cudf_polars.experimental.dispatch import LowerIRTransformer
+     from cudf_polars.experimental.parallel import PartitionInfo
+     from cudf_polars.typing import Schema
+     from cudf_polars.utils.config import ShuffleMethod
+
+
+ # Supported shuffle methods
+ _SHUFFLE_METHODS = ("rapidsmpf", "tasks")
+
+
+ class ShuffleOptions(TypedDict):
+     """RapidsMPF shuffling options."""
+
+     on: Sequence[str]
+     column_names: Sequence[str]
+     dtypes: Sequence[DataType]
+     cluster_kind: Literal["dask", "single"]
+
+
+ # Experimental rapidsmpf shuffler integration
+ class RMPFIntegration:  # pragma: no cover
+     """cuDF-Polars protocol for rapidsmpf shuffler."""
+
+     @staticmethod
+     @nvtx_annotate_cudf_polars(message="RMPFIntegration.insert_partition")
+     def insert_partition(
+         df: DataFrame,
+         partition_id: int,  # Not currently used
+         partition_count: int,
+         shuffler: Any,
+         options: ShuffleOptions,
+         *other: Any,
+     ) -> None:
+         """Add cudf-polars DataFrame chunks to an RMP shuffler."""
+         from rapidsmpf.integrations.cudf.partition import partition_and_pack
+
+         if options["cluster_kind"] == "dask":
+             from rapidsmpf.integrations.dask import get_worker_context
+
+         else:
+             from rapidsmpf.integrations.single import get_worker_context
+
+         context = get_worker_context()
+
+         on = options["on"]
+         assert not other, f"Unexpected arguments: {other}"
+         columns_to_hash = tuple(df.column_names.index(val) for val in on)
+         packed_inputs = partition_and_pack(
+             df.table,
+             columns_to_hash=columns_to_hash,
+             num_partitions=partition_count,
+             br=context.br,
+             stream=DEFAULT_STREAM,
+         )
+         shuffler.insert_chunks(packed_inputs)
+
+     @staticmethod
+     @nvtx_annotate_cudf_polars(message="RMPFIntegration.extract_partition")
+     def extract_partition(
+         partition_id: int,
+         shuffler: Any,
+         options: ShuffleOptions,
+     ) -> DataFrame:
+         """Extract a finished partition from the RMP shuffler."""
+         from rapidsmpf.integrations.cudf.partition import (
+             unpack_and_concat,
+             unspill_partitions,
+         )
+
+         if options["cluster_kind"] == "dask":
+             from rapidsmpf.integrations.dask import get_worker_context
+
+         else:
+             from rapidsmpf.integrations.single import get_worker_context
+
+         context = get_worker_context()
+
+         shuffler.wait_on(partition_id)
+         column_names = options["column_names"]
+         dtypes = options["dtypes"]
+         return DataFrame.from_table(
+             unpack_and_concat(
+                 unspill_partitions(
+                     shuffler.extract(partition_id),
+                     br=context.br,
+                     allow_overbooking=True,
+                     statistics=context.statistics,
+                 ),
+                 br=context.br,
+                 stream=DEFAULT_STREAM,
+             ),
+             column_names,
+             dtypes,
+         )
+
+
+ class Shuffle(IR):
+     """
+     Shuffle multi-partition data.
+
+     Notes
+     -----
+     Only hash-based partitioning is supported (for now). See
+     `ShuffleSorted` for sorting-based shuffling.
+     """
+
+     __slots__ = ("keys", "shuffle_method")
+     _non_child = ("schema", "keys", "shuffle_method")
+     keys: tuple[NamedExpr, ...]
+     """Keys to shuffle on."""
+     shuffle_method: ShuffleMethod
+     """Shuffle method to use."""
+
+     def __init__(
+         self,
+         schema: Schema,
+         keys: tuple[NamedExpr, ...],
+         shuffle_method: ShuffleMethod,
+         df: IR,
+     ):
+         self.schema = schema
+         self.keys = keys
+         self.shuffle_method = shuffle_method
+         self._non_child_args = (schema, keys, shuffle_method)
+         self.children = (df,)
+
+     @classmethod
+     def do_evaluate(
+         cls,
+         schema: Schema,
+         keys: tuple[NamedExpr, ...],
+         shuffle_method: ShuffleMethod,
+         df: DataFrame,
+     ) -> DataFrame:  # pragma: no cover
+         """Evaluate and return a dataframe."""
+         # Single-partition Shuffle evaluation is a no-op
+         return df
+
+
+ @nvtx_annotate_cudf_polars(message="Shuffle")
+ def _hash_partition_dataframe(
+     df: DataFrame,
+     partition_id: int,  # Used only by sorted shuffling
+     partition_count: int,
+     options: MutableMapping[str, Any] | None,  # No options required
+     on: tuple[NamedExpr, ...],
+ ) -> dict[int, DataFrame]:
+     """
+     Partition an input DataFrame for hash-based shuffling.
+
+     Parameters
+     ----------
+     df
+         DataFrame to partition.
+     partition_id
+         Partition index (unused for hash partitioning).
+     partition_count
+         Total number of output partitions.
+     options
+         Options (unused for hash partitioning).
+     on
+         Expressions used for the hash partitioning.
+
+     Returns
+     -------
+     A dictionary mapping between int partition indices and
+     DataFrame fragments.
+     """
+     assert not options, f"Expected no options, got: {options}"
+
+     if df.num_rows == 0:
+         # Fast path for empty DataFrame
+         return dict.fromkeys(range(partition_count), df)
+
+     # Hash the specified keys to calculate the output
+     # partition for each row
+     partition_map = plc.binaryop.binary_operation(
+         plc.hashing.murmurhash3_x86_32(
+             DataFrame([expr.evaluate(df) for expr in on]).table
+         ),
+         plc.Scalar.from_py(partition_count, plc.DataType(plc.TypeId.UINT32)),
+         plc.binaryop.BinaryOperator.PYMOD,
+         plc.types.DataType(plc.types.TypeId.UINT32),
+     )
+
+     # Apply partitioning
+     t, offsets = plc.partitioning.partition(
+         df.table,
+         partition_map,
+         partition_count,
+     )
+     splits = offsets[1:-1]
+
+     # Split and return the partitioned result
+     return {
+         i: DataFrame.from_table(
+             split,
+             df.column_names,
+             df.dtypes,
+         )
+         for i, split in enumerate(plc.copying.split(t, splits))
+     }
+
+
+ # When dropping Python 3.10, can use _simple_shuffle_graph[OPT_T](...)
+ OPT_T = TypeVar("OPT_T")
+
+
+ def _simple_shuffle_graph(
+     name_in: str,
+     name_out: str,
+     count_in: int,
+     count_out: int,
+     _partition_dataframe_func: Callable[
+         Concatenate[DataFrame, int, int, OPT_T, ...],
+         MutableMapping[int, DataFrame],
+     ],
+     options: OPT_T,
+     *other: Any,
+ ) -> MutableMapping[Any, Any]:
+     """Make a simple all-to-all shuffle graph."""
+     split_name = f"split-{name_out}"
+     inter_name = f"inter-{name_out}"
+
+     graph: MutableMapping[Any, Any] = {}
+     for part_out in range(count_out):
+         _concat_list = []
+         for part_in in range(count_in):
+             graph[(split_name, part_in)] = (
+                 _partition_dataframe_func,
+                 (name_in, part_in),
+                 part_in,
+                 count_out,
+                 options,
+                 *other,
+             )
+             _concat_list.append((inter_name, part_out, part_in))
+             graph[_concat_list[-1]] = (
+                 operator.getitem,
+                 (split_name, part_in),
+                 part_out,
+             )
+         graph[(name_out, part_out)] = (_concat, *_concat_list)
+     return graph
+
+
+ @lower_ir_node.register(Shuffle)
+ def _(
+     ir: Shuffle, rec: LowerIRTransformer
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+     # Simple lower_ir_node handling for the default hash-based shuffle.
+     # More-complex logic (e.g. joining and sorting) should
+     # be handled separately.
+     from cudf_polars.experimental.parallel import PartitionInfo
+
+     (child,) = ir.children
+
+     new_child, pi = rec(child)
+     if pi[new_child].count == 1 or ir.keys == pi[new_child].partitioned_on:
+         # Already shuffled
+         return new_child, pi
+     new_node = ir.reconstruct([new_child])
+     pi[new_node] = PartitionInfo(
+         # Default shuffle preserves partition count
+         count=pi[new_child].count,
+         # Add partitioned_on info
+         partitioned_on=ir.keys,
+     )
+     return new_node, pi
+
+
+ @generate_ir_tasks.register(Shuffle)
+ def _(
+     ir: Shuffle, partition_info: MutableMapping[IR, PartitionInfo]
+ ) -> MutableMapping[Any, Any]:
+     # Extract "shuffle_method" configuration
+     shuffle_method = ir.shuffle_method
+
+     # Try using rapidsmpf shuffler if we have "simple" shuffle
+     # keys, and the "shuffle_method" config is set to "rapidsmpf"
+     _keys: list[Col]
+     if shuffle_method in ("rapidsmpf", "rapidsmpf-single") and len(
+         _keys := [ne.value for ne in ir.keys if isinstance(ne.value, Col)]
+     ) == len(ir.keys):  # pragma: no cover
+         cluster_kind: Literal["dask", "single"]
+         if shuffle_method == "rapidsmpf-single":
+             from rapidsmpf.integrations.single import rapidsmpf_shuffle_graph
+
+             cluster_kind = "single"
+         else:
+             from rapidsmpf.integrations.dask import rapidsmpf_shuffle_graph
+
+             cluster_kind = "dask"
+
+         shuffle_on = [k.name for k in _keys]
+
+         try:
+             return rapidsmpf_shuffle_graph(
+                 get_key_name(ir.children[0]),
+                 get_key_name(ir),
+                 partition_info[ir.children[0]].count,
+                 partition_info[ir].count,
+                 RMPFIntegration,
+                 {
+                     "on": shuffle_on,
+                     "column_names": list(ir.schema.keys()),
+                     "dtypes": list(ir.schema.values()),
+                     "cluster_kind": cluster_kind,
+                 },
+             )
+         except ValueError as err:
+             # ValueError: rapidsmpf couldn't find a distributed client
+             if shuffle_method == "rapidsmpf":
+                 # Only raise an error if the user specifically
+                 # set the shuffle method to "rapidsmpf"
+                 raise ValueError(
+                     "The current Dask cluster does not support rapidsmpf shuffling."
+                 ) from err
+
+     # Simple task-based fall-back
+     return _simple_shuffle_graph(
+         get_key_name(ir.children[0]),
+         get_key_name(ir),
+         partition_info[ir.children[0]].count,
+         partition_info[ir].count,
+         _hash_partition_dataframe,
+         None,
+         ir.keys,
+     )
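
The heart of the task-based fallback is easy to model in plain Python: _hash_partition_dataframe routes each row to murmurhash3(keys) % partition_count, and _simple_shuffle_graph wires count_in split tasks to count_out concatenation tasks. A toy, CPU-only model of both (the placeholder names and Python's built-in hash are stand-ins, not the library's API):

def toy_hash_partition(rows, key, partition_count):
    """Toy version of _hash_partition_dataframe: route each row to
    hash(row[key]) % partition_count. The real code hashes with
    murmurhash3 on the GPU and splits a pylibcudf table."""
    buckets = {i: [] for i in range(partition_count)}
    for row in rows:
        buckets[hash(row[key]) % partition_count].append(row)
    return buckets

def toy_shuffle_graph(name_in, name_out, count_in, count_out):
    """Toy version of _simple_shuffle_graph: same task-key layout,
    with strings standing in for the callables."""
    graph = {}
    for part_in in range(count_in):
        # One split task per input partition.
        graph[("split", part_in)] = ("partition", (name_in, part_in), count_out)
    for part_out in range(count_out):
        # One concat task per output partition, reading its slice
        # out of every split result (all-to-all).
        graph[(name_out, part_out)] = (
            "concat",
            *(
                ("getitem", ("split", part_in), part_out)
                for part_in in range(count_in)
            ),
        )
    return graph

rows = [{"k": v} for v in range(8)]
print(toy_hash_partition(rows, "k", partition_count=3))
print(toy_shuffle_graph("in", "out", count_in=2, count_out=3))

Note the all-to-all structure: every output partition reads one slice from every input partition, so the graph holds count_in x count_out intermediate keys; the rapidsmpf path exists to avoid materializing all of those intermediates as separate tasks.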