cudf-polars-cu12 25.2.2__py3-none-any.whl → 25.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/callback.py +82 -65
  3. cudf_polars/containers/column.py +138 -7
  4. cudf_polars/containers/dataframe.py +26 -39
  5. cudf_polars/dsl/expr.py +3 -1
  6. cudf_polars/dsl/expressions/aggregation.py +27 -63
  7. cudf_polars/dsl/expressions/base.py +40 -72
  8. cudf_polars/dsl/expressions/binaryop.py +5 -41
  9. cudf_polars/dsl/expressions/boolean.py +25 -53
  10. cudf_polars/dsl/expressions/datetime.py +97 -17
  11. cudf_polars/dsl/expressions/literal.py +27 -33
  12. cudf_polars/dsl/expressions/rolling.py +110 -9
  13. cudf_polars/dsl/expressions/selection.py +8 -26
  14. cudf_polars/dsl/expressions/slicing.py +47 -0
  15. cudf_polars/dsl/expressions/sorting.py +5 -18
  16. cudf_polars/dsl/expressions/string.py +33 -36
  17. cudf_polars/dsl/expressions/ternary.py +3 -10
  18. cudf_polars/dsl/expressions/unary.py +35 -75
  19. cudf_polars/dsl/ir.py +749 -212
  20. cudf_polars/dsl/nodebase.py +8 -1
  21. cudf_polars/dsl/to_ast.py +5 -3
  22. cudf_polars/dsl/translate.py +319 -171
  23. cudf_polars/dsl/utils/__init__.py +8 -0
  24. cudf_polars/dsl/utils/aggregations.py +292 -0
  25. cudf_polars/dsl/utils/groupby.py +97 -0
  26. cudf_polars/dsl/utils/naming.py +34 -0
  27. cudf_polars/dsl/utils/replace.py +46 -0
  28. cudf_polars/dsl/utils/rolling.py +113 -0
  29. cudf_polars/dsl/utils/windows.py +186 -0
  30. cudf_polars/experimental/base.py +17 -19
  31. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  32. cudf_polars/experimental/benchmarks/pdsh.py +1279 -0
  33. cudf_polars/experimental/dask_registers.py +196 -0
  34. cudf_polars/experimental/distinct.py +174 -0
  35. cudf_polars/experimental/explain.py +127 -0
  36. cudf_polars/experimental/expressions.py +521 -0
  37. cudf_polars/experimental/groupby.py +288 -0
  38. cudf_polars/experimental/io.py +58 -29
  39. cudf_polars/experimental/join.py +353 -0
  40. cudf_polars/experimental/parallel.py +166 -93
  41. cudf_polars/experimental/repartition.py +69 -0
  42. cudf_polars/experimental/scheduler.py +155 -0
  43. cudf_polars/experimental/select.py +92 -7
  44. cudf_polars/experimental/shuffle.py +294 -0
  45. cudf_polars/experimental/sort.py +45 -0
  46. cudf_polars/experimental/spilling.py +151 -0
  47. cudf_polars/experimental/utils.py +100 -0
  48. cudf_polars/testing/asserts.py +146 -6
  49. cudf_polars/testing/io.py +72 -0
  50. cudf_polars/testing/plugin.py +78 -76
  51. cudf_polars/typing/__init__.py +59 -6
  52. cudf_polars/utils/config.py +353 -0
  53. cudf_polars/utils/conversion.py +40 -0
  54. cudf_polars/utils/dtypes.py +22 -5
  55. cudf_polars/utils/timer.py +39 -0
  56. cudf_polars/utils/versions.py +5 -4
  57. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/METADATA +10 -7
  58. cudf_polars_cu12-25.6.0.dist-info/RECORD +73 -0
  59. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/WHEEL +1 -1
  60. cudf_polars/experimental/dask_serialize.py +0 -59
  61. cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
  62. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info/licenses}/LICENSE +0 -0
  63. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,69 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Repartitioning Logic."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import itertools
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ from cudf_polars.dsl.ir import IR
11
+ from cudf_polars.experimental.base import get_key_name
12
+ from cudf_polars.experimental.dispatch import generate_ir_tasks
13
+ from cudf_polars.experimental.utils import _concat
14
+
15
+ if TYPE_CHECKING:
16
+ from collections.abc import MutableMapping
17
+
18
+ from cudf_polars.experimental.parallel import PartitionInfo
19
+ from cudf_polars.typing import Schema
20
+
21
+
22
+ class Repartition(IR):
23
+ """
24
+ Repartition a DataFrame.
25
+
26
+ Notes
27
+ -----
28
+ Repartitioning means that we are not modifying any
29
+ data, nor are we reordering or shuffling rows. We
30
+ are only changing the overall partition count. For
31
+ now, we only support an N -> [1...N] repartitioning
32
+ (inclusive). The output partition count is tracked
33
+ separately using PartitionInfo.
34
+ """
35
+
36
+ __slots__ = ()
37
+ _non_child = ("schema",)
38
+
39
+ def __init__(self, schema: Schema, df: IR):
40
+ self.schema = schema
41
+ self._non_child_args = ()
42
+ self.children = (df,)
43
+
44
+
45
+ @generate_ir_tasks.register(Repartition)
46
+ def _(
47
+ ir: Repartition, partition_info: MutableMapping[IR, PartitionInfo]
48
+ ) -> MutableMapping[Any, Any]:
49
+ # Repartition an IR node.
50
+ # Only supports repartitioning to fewer (for now).
51
+
52
+ (child,) = ir.children
53
+ count_in = partition_info[child].count
54
+ count_out = partition_info[ir].count
55
+
56
+ if count_out > count_in: # pragma: no cover
57
+ raise NotImplementedError(
58
+ f"Repartition {count_in} -> {count_out} not supported."
59
+ )
60
+
61
+ key_name = get_key_name(ir)
62
+ n, remainder = divmod(count_in, count_out)
63
+ # Spread remainder evenly over the partitions.
64
+ offsets = [0, *itertools.accumulate(n + (i < remainder) for i in range(count_out))]
65
+ child_keys = tuple(partition_info[child].keys(child))
66
+ return {
67
+ (key_name, i): (_concat, *child_keys[offsets[i] : offsets[i + 1]])
68
+ for i in range(count_out)
69
+ }
@@ -0,0 +1,155 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Synchronous task scheduler."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections import Counter
8
+ from collections.abc import MutableMapping
9
+ from itertools import chain
10
+ from typing import TYPE_CHECKING, Any, TypeVar
11
+
12
+ from typing_extensions import Unpack
13
+
14
+ if TYPE_CHECKING:
15
+ from collections.abc import Mapping
16
+ from typing import TypeAlias
17
+
18
+
19
+ Key: TypeAlias = str | tuple[str, Unpack[tuple[int, ...]]]
20
+ Graph: TypeAlias = MutableMapping[Key, Any]
21
+ T_ = TypeVar("T_")
22
+
23
+
24
+ # NOTE: This is a slimmed-down version of the single-threaded
25
+ # (synchronous) scheduler in `dask.core`.
26
+ #
27
+ # Key Differences:
28
+ # * We do not allow a task to contain a list of key names.
29
+ # Keys must be distinct elements of the task.
30
+ # * We do not support nested tasks.
31
+
32
+
33
+ def istask(x: Any) -> bool:
34
+ """Check if x is a callable task."""
35
+ return isinstance(x, tuple) and bool(x) and callable(x[0])
36
+
37
+
38
+ def is_hashable(x: Any) -> bool:
39
+ """Check if x is hashable."""
40
+ try:
41
+ hash(x)
42
+ except BaseException:
43
+ return False
44
+ else:
45
+ return True
46
+
47
+
48
+ def _execute_task(arg: Any, cache: Mapping) -> Any:
49
+ """Execute a compute task."""
50
+ if istask(arg):
51
+ return arg[0](*(_execute_task(a, cache) for a in arg[1:]))
52
+ elif is_hashable(arg):
53
+ return cache.get(arg, arg)
54
+ else:
55
+ return arg
56
+
57
+
58
+ def required_keys(key: Key, graph: Graph) -> list[Key]:
59
+ """
60
+ Return the dependencies to extract a key from the graph.
61
+
62
+ Parameters
63
+ ----------
64
+ key
65
+ Root key we want to extract.
66
+ graph
67
+ The full task graph.
68
+
69
+ Returns
70
+ -------
71
+ List of other keys needed to extract ``key``.
72
+ """
73
+ maybe_task = graph[key]
74
+ return [
75
+ k
76
+ for k in (
77
+ maybe_task[1:]
78
+ if istask(maybe_task)
79
+ else [maybe_task] # maybe_task might be a key
80
+ )
81
+ if is_hashable(k) and k in graph
82
+ ]
83
+
84
+
85
+ def toposort(graph: Graph, dependencies: Mapping[Key, list[Key]]) -> list[Key]:
86
+ """Return a list of task keys sorted in topological order."""
87
+ # Stack-based depth-first search traversal. This is based on Tarjan's
88
+ # algorithm for strongly-connected components
89
+ # (https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm)
90
+ ordered: list[Key] = []
91
+ completed: set[Key] = set()
92
+
93
+ for key in graph:
94
+ if key in completed:
95
+ continue
96
+ nodes = [key]
97
+ while nodes:
98
+ # Keep current node on the stack until all descendants are visited
99
+ current = nodes[-1]
100
+ if current in completed: # pragma: no cover
101
+ # Already fully traversed descendants of current
102
+ nodes.pop()
103
+ continue
104
+
105
+ # Add direct descendants of current to nodes stack
106
+ next_nodes = set(dependencies[current]) - completed
107
+ if next_nodes:
108
+ nodes.extend(next_nodes)
109
+ else:
110
+ # Current has no more descendants to explore
111
+ ordered.append(current)
112
+ completed.add(current)
113
+ nodes.pop()
114
+
115
+ return ordered
116
+
117
+
118
+ def synchronous_scheduler(
119
+ graph: Graph,
120
+ key: Key,
121
+ *,
122
+ cache: MutableMapping | None = None,
123
+ ) -> Any:
124
+ """
125
+ Execute the task graph for a given key.
126
+
127
+ Parameters
128
+ ----------
129
+ graph
130
+ The task graph to execute.
131
+ key
132
+ The final output key to extract from the graph.
133
+ cache
134
+ Intermediate-data cache.
135
+
136
+ Returns
137
+ -------
138
+ Executed task-graph result for ``key``.
139
+ """
140
+ if key not in graph: # pragma: no cover
141
+ raise KeyError(f"{key} is not a key in the graph")
142
+ if cache is None:
143
+ cache = {}
144
+
145
+ dependencies = {k: required_keys(k, graph) for k in graph}
146
+ refcount = Counter(chain.from_iterable(dependencies.values()))
147
+
148
+ for k in toposort(graph, dependencies):
149
+ cache[k] = _execute_task(graph[k], cache)
150
+ for dep in dependencies[k]:
151
+ refcount[dep] -= 1
152
+ if refcount[dep] == 0 and dep != key:
153
+ del cache[dep]
154
+
155
+ return cache[key]
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  """Parallel Select Logic."""
4
4
 
@@ -6,16 +6,95 @@ from __future__ import annotations
6
6
 
7
7
  from typing import TYPE_CHECKING
8
8
 
9
- from cudf_polars.dsl.ir import Select
9
+ from cudf_polars.dsl.ir import HConcat, Select
10
10
  from cudf_polars.dsl.traversal import traversal
11
+ from cudf_polars.experimental.base import PartitionInfo
11
12
  from cudf_polars.experimental.dispatch import lower_ir_node
13
+ from cudf_polars.experimental.expressions import decompose_expr_graph
14
+ from cudf_polars.experimental.utils import _lower_ir_fallback
12
15
 
13
16
  if TYPE_CHECKING:
14
17
  from collections.abc import MutableMapping
15
18
 
16
19
  from cudf_polars.dsl.ir import IR
17
- from cudf_polars.experimental.base import PartitionInfo
18
20
  from cudf_polars.experimental.parallel import LowerIRTransformer
21
+ from cudf_polars.utils.config import ConfigOptions
22
+
23
+
24
+ def decompose_select(
25
+ select_ir: Select,
26
+ input_ir: IR,
27
+ partition_info: MutableMapping[IR, PartitionInfo],
28
+ config_options: ConfigOptions,
29
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
30
+ """
31
+ Decompose a multi-partition Select operation.
32
+
33
+ Parameters
34
+ ----------
35
+ select_ir
36
+ The original Select operation to decompose.
37
+ This object has not been reconstructed with
38
+ ``input_ir`` as its child yet.
39
+ input_ir
40
+ The lowered child of ``select_ir``. This object
41
+ will be decomposed into a "partial" selection
42
+ for each element of ``select_ir.exprs``.
43
+ partition_info
44
+ A mapping from all unique IR nodes to the
45
+ associated partitioning information.
46
+ config_options
47
+ GPUEngine configuration options.
48
+
49
+ Returns
50
+ -------
51
+ new_ir, partition_info
52
+ The rewritten Select node, and a mapping from
53
+ unique nodes in the new graph to associated
54
+ partitioning information.
55
+
56
+ Notes
57
+ -----
58
+ This function uses ``decompose_expr_graph`` to further
59
+ decompose each element of ``select_ir.exprs``.
60
+
61
+ See Also
62
+ --------
63
+ decompose_expr_graph
64
+ """
65
+ # Collect partial selections
66
+ selections = []
67
+ for ne in select_ir.exprs:
68
+ # Decompose this partial expression
69
+ new_ne, partial_input_ir, _partition_info = decompose_expr_graph(
70
+ ne, input_ir, partition_info, config_options
71
+ )
72
+ pi = _partition_info[partial_input_ir]
73
+ partial_input_ir = Select(
74
+ {ne.name: ne.value.dtype},
75
+ [new_ne],
76
+ True, # noqa: FBT003
77
+ partial_input_ir,
78
+ )
79
+ _partition_info[partial_input_ir] = pi
80
+ partition_info.update(_partition_info)
81
+ selections.append(partial_input_ir)
82
+
83
+ # Concatenate partial selections
84
+ new_ir: HConcat | Select
85
+ if len(selections) > 1:
86
+ new_ir = HConcat(
87
+ select_ir.schema,
88
+ True, # noqa: FBT003
89
+ *selections,
90
+ )
91
+ partition_info[new_ir] = PartitionInfo(
92
+ count=max(partition_info[c].count for c in selections)
93
+ )
94
+ else:
95
+ new_ir = selections[0]
96
+
97
+ return new_ir, partition_info
19
98
 
20
99
 
21
100
  @lower_ir_node.register(Select)
@@ -27,10 +106,16 @@ def _(
27
106
  if pi.count > 1 and not all(
28
107
  expr.is_pointwise for expr in traversal([e.value for e in ir.exprs])
29
108
  ):
30
- # TODO: Handle non-pointwise expressions.
31
- raise NotImplementedError(
32
- f"Selection {ir} does not support multiple partitions."
33
- )
109
+ try:
110
+ # Try decomposing the underlying expressions
111
+ return decompose_select(
112
+ ir, child, partition_info, rec.state["config_options"]
113
+ )
114
+ except NotImplementedError:
115
+ return _lower_ir_fallback(
116
+ ir, rec, msg="This selection is not supported for multiple partitions."
117
+ )
118
+
34
119
  new_node = ir.reconstruct([child])
35
120
  partition_info[new_node] = pi
36
121
  return new_node, partition_info
@@ -0,0 +1,294 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Shuffle Logic."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import operator
8
+ from typing import TYPE_CHECKING, Any, TypedDict
9
+
10
+ import pylibcudf as plc
11
+ import rmm.mr
12
+ from rmm.pylibrmm.stream import DEFAULT_STREAM
13
+
14
+ from cudf_polars.containers import DataFrame
15
+ from cudf_polars.dsl.expr import Col
16
+ from cudf_polars.dsl.ir import IR
17
+ from cudf_polars.experimental.base import get_key_name
18
+ from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
19
+ from cudf_polars.experimental.utils import _concat
20
+
21
+ if TYPE_CHECKING:
22
+ from collections.abc import MutableMapping, Sequence
23
+
24
+ from cudf_polars.dsl.expr import NamedExpr
25
+ from cudf_polars.experimental.dispatch import LowerIRTransformer
26
+ from cudf_polars.experimental.parallel import PartitionInfo
27
+ from cudf_polars.typing import Schema
28
+ from cudf_polars.utils.config import ConfigOptions
29
+
30
+
31
+ # Supported shuffle methods
32
+ _SHUFFLE_METHODS = ("rapidsmpf", "tasks")
33
+
34
+
35
+ class ShuffleOptions(TypedDict):
36
+ """RapidsMPF shuffling options."""
37
+
38
+ on: Sequence[str]
39
+ column_names: Sequence[str]
40
+
41
+
42
+ # Experimental rapidsmpf shuffler integration
43
+ class RMPFIntegration: # pragma: no cover
44
+ """cuDF-Polars protocol for rapidsmpf shuffler."""
45
+
46
+ @staticmethod
47
+ def insert_partition(
48
+ df: DataFrame,
49
+ partition_id: int, # Not currently used
50
+ partition_count: int,
51
+ shuffler: Any,
52
+ options: ShuffleOptions,
53
+ *other: Any,
54
+ ) -> None:
55
+ """Add cudf-polars DataFrame chunks to an RMP shuffler."""
56
+ from rapidsmpf.shuffler import partition_and_pack
57
+
58
+ on = options["on"]
59
+ assert not other, f"Unexpected arguments: {other}"
60
+ columns_to_hash = tuple(df.column_names.index(val) for val in on)
61
+ packed_inputs = partition_and_pack(
62
+ df.table,
63
+ columns_to_hash=columns_to_hash,
64
+ num_partitions=partition_count,
65
+ stream=DEFAULT_STREAM,
66
+ device_mr=rmm.mr.get_current_device_resource(),
67
+ )
68
+ shuffler.insert_chunks(packed_inputs)
69
+
70
+ @staticmethod
71
+ def extract_partition(
72
+ partition_id: int,
73
+ shuffler: Any,
74
+ options: ShuffleOptions,
75
+ ) -> DataFrame:
76
+ """Extract a finished partition from the RMP shuffler."""
77
+ from rapidsmpf.shuffler import unpack_and_concat
78
+
79
+ shuffler.wait_on(partition_id)
80
+ column_names = options["column_names"]
81
+ return DataFrame.from_table(
82
+ unpack_and_concat(
83
+ shuffler.extract(partition_id),
84
+ stream=DEFAULT_STREAM,
85
+ device_mr=rmm.mr.get_current_device_resource(),
86
+ ),
87
+ column_names,
88
+ )
89
+
90
+
91
+ class Shuffle(IR):
92
+ """
93
+ Shuffle multi-partition data.
94
+
95
+ Notes
96
+ -----
97
+ Only hash-based partitioning is supported (for now).
98
+ """
99
+
100
+ __slots__ = ("config_options", "keys")
101
+ _non_child = ("schema", "keys", "config_options")
102
+ keys: tuple[NamedExpr, ...]
103
+ """Keys to shuffle on."""
104
+ config_options: ConfigOptions
105
+ """Configuration options."""
106
+
107
+ def __init__(
108
+ self,
109
+ schema: Schema,
110
+ keys: tuple[NamedExpr, ...],
111
+ config_options: ConfigOptions,
112
+ df: IR,
113
+ ):
114
+ self.schema = schema
115
+ self.keys = keys
116
+ self.config_options = config_options
117
+ self._non_child_args = (schema, keys, config_options)
118
+ self.children = (df,)
119
+
120
+ @classmethod
121
+ def do_evaluate(
122
+ cls,
123
+ schema: Schema,
124
+ keys: tuple[NamedExpr, ...],
125
+ config_options: ConfigOptions,
126
+ df: DataFrame,
127
+ ) -> DataFrame: # pragma: no cover
128
+ """Evaluate and return a dataframe."""
129
+ # Single-partition Shuffle evaluation is a no-op
130
+ return df
131
+
132
+
133
+ def _partition_dataframe(
134
+ df: DataFrame,
135
+ keys: tuple[NamedExpr, ...],
136
+ count: int,
137
+ ) -> dict[int, DataFrame]:
138
+ """
139
+ Partition an input DataFrame for shuffling.
140
+
141
+ Notes
142
+ -----
143
+ This utility only supports hash partitioning (for now).
144
+
145
+ Parameters
146
+ ----------
147
+ df
148
+ DataFrame to partition.
149
+ keys
150
+ Shuffle key(s).
151
+ count
152
+ Total number of output partitions.
153
+
154
+ Returns
155
+ -------
156
+ A dictionary mapping between int partition indices and
157
+ DataFrame fragments.
158
+ """
159
+ if df.num_rows == 0:
160
+ # Fast path for empty DataFrame
161
+ return {i: df for i in range(count)}
162
+
163
+ # Hash the specified keys to calculate the output
164
+ # partition for each row
165
+ partition_map = plc.binaryop.binary_operation(
166
+ plc.hashing.murmurhash3_x86_32(
167
+ DataFrame([expr.evaluate(df) for expr in keys]).table
168
+ ),
169
+ plc.Scalar.from_py(count, plc.DataType(plc.TypeId.UINT32)),
170
+ plc.binaryop.BinaryOperator.PYMOD,
171
+ plc.types.DataType(plc.types.TypeId.UINT32),
172
+ )
173
+
174
+ # Apply partitioning
175
+ t, offsets = plc.partitioning.partition(
176
+ df.table,
177
+ partition_map,
178
+ count,
179
+ )
180
+
181
+ # Split and return the partitioned result
182
+ return {
183
+ i: DataFrame.from_table(
184
+ split,
185
+ df.column_names,
186
+ )
187
+ for i, split in enumerate(plc.copying.split(t, offsets[1:-1]))
188
+ }
189
+
190
+
191
+ def _simple_shuffle_graph(
192
+ name_in: str,
193
+ name_out: str,
194
+ keys: tuple[NamedExpr, ...],
195
+ count_in: int,
196
+ count_out: int,
197
+ ) -> MutableMapping[Any, Any]:
198
+ """Make a simple all-to-all shuffle graph."""
199
+ split_name = f"split-{name_out}"
200
+ inter_name = f"inter-{name_out}"
201
+
202
+ graph: MutableMapping[Any, Any] = {}
203
+ for part_out in range(count_out):
204
+ _concat_list = []
205
+ for part_in in range(count_in):
206
+ graph[(split_name, part_in)] = (
207
+ _partition_dataframe,
208
+ (name_in, part_in),
209
+ keys,
210
+ count_out,
211
+ )
212
+ _concat_list.append((inter_name, part_out, part_in))
213
+ graph[_concat_list[-1]] = (
214
+ operator.getitem,
215
+ (split_name, part_in),
216
+ part_out,
217
+ )
218
+ graph[(name_out, part_out)] = (_concat, *_concat_list)
219
+ return graph
220
+
221
+
222
+ @lower_ir_node.register(Shuffle)
223
+ def _(
224
+ ir: Shuffle, rec: LowerIRTransformer
225
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
226
+ # Simple lower_ir_node handling for the default hash-based shuffle.
227
+ # More-complex logic (e.g. joining and sorting) should
228
+ # be handled separately.
229
+ from cudf_polars.experimental.parallel import PartitionInfo
230
+
231
+ (child,) = ir.children
232
+
233
+ new_child, pi = rec(child)
234
+ if pi[new_child].count == 1 or ir.keys == pi[new_child].partitioned_on:
235
+ # Already shuffled
236
+ return new_child, pi
237
+ new_node = ir.reconstruct([new_child])
238
+ pi[new_node] = PartitionInfo(
239
+ # Default shuffle preserves partition count
240
+ count=pi[new_child].count,
241
+ # Add partitioned_on info
242
+ partitioned_on=ir.keys,
243
+ )
244
+ return new_node, pi
245
+
246
+
247
+ @generate_ir_tasks.register(Shuffle)
248
+ def _(
249
+ ir: Shuffle, partition_info: MutableMapping[IR, PartitionInfo]
250
+ ) -> MutableMapping[Any, Any]:
251
+ # Extract "shuffle_method" configuration
252
+ assert ir.config_options.executor.name == "streaming", (
253
+ "'in-memory' executor not supported in 'generate_ir_tasks'"
254
+ )
255
+
256
+ shuffle_method = ir.config_options.executor.shuffle_method
257
+
258
+ # Try using rapidsmpf shuffler if we have "simple" shuffle
259
+ # keys, and the "shuffle_method" config is set to "rapidsmpf"
260
+ _keys: list[Col]
261
+ if shuffle_method in (None, "rapidsmpf") and len(
262
+ _keys := [ne.value for ne in ir.keys if isinstance(ne.value, Col)]
263
+ ) == len(ir.keys): # pragma: no cover
264
+ shuffle_on = [k.name for k in _keys]
265
+ try:
266
+ from rapidsmpf.integrations.dask import rapidsmpf_shuffle_graph
267
+
268
+ return rapidsmpf_shuffle_graph(
269
+ get_key_name(ir.children[0]),
270
+ get_key_name(ir),
271
+ partition_info[ir.children[0]].count,
272
+ partition_info[ir].count,
273
+ RMPFIntegration,
274
+ {"on": shuffle_on, "column_names": list(ir.schema.keys())},
275
+ )
276
+ except (ImportError, ValueError) as err:
277
+ # ImportError: rapidsmpf is not installed
278
+ # ValueError: rapidsmpf couldn't find a distributed client
279
+ if shuffle_method == "rapidsmpf":
280
+ # Only raise an error if the user specifically
281
+ # set the shuffle method to "rapidsmpf"
282
+ raise ValueError(
283
+ "Rapidsmp is not installed correctly or the current "
284
+ "Dask cluster does not support rapidsmpf shuffling."
285
+ ) from err
286
+
287
+ # Simple task-based fall-back
288
+ return _simple_shuffle_graph(
289
+ get_key_name(ir.children[0]),
290
+ get_key_name(ir),
291
+ ir.keys,
292
+ partition_info[ir.children[0]].count,
293
+ partition_info[ir].count,
294
+ )
@@ -0,0 +1,45 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Sorting Logic."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING
8
+
9
+ from cudf_polars.dsl.ir import Sort
10
+ from cudf_polars.experimental.base import PartitionInfo
11
+ from cudf_polars.experimental.dispatch import lower_ir_node
12
+ from cudf_polars.experimental.repartition import Repartition
13
+ from cudf_polars.experimental.utils import _lower_ir_fallback
14
+
15
+ if TYPE_CHECKING:
16
+ from collections.abc import MutableMapping
17
+
18
+ from cudf_polars.dsl.ir import IR
19
+ from cudf_polars.experimental.dispatch import LowerIRTransformer
20
+
21
+
22
+ @lower_ir_node.register(Sort)
23
+ def _(
24
+ ir: Sort, rec: LowerIRTransformer
25
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
26
+ # Special handling for slicing
27
+ # (May be a top- or bottom-k operation)
28
+
29
+ if ir.zlice is not None and ir.zlice[0] < 1:
30
+ # TODO: Handle large slices (e.g. 1m+ rows)
31
+ from cudf_polars.experimental.parallel import _lower_ir_pwise
32
+
33
+ # Sort input partitions
34
+ new_node, partition_info = _lower_ir_pwise(ir, rec)
35
+ if partition_info[new_node].count > 1:
36
+ # Collapse down to single partition
37
+ inter = Repartition(new_node.schema, new_node)
38
+ partition_info[inter] = PartitionInfo(count=1)
39
+ # Sort reduced partition
40
+ new_node = ir.reconstruct([inter])
41
+ partition_info[new_node] = PartitionInfo(count=1)
42
+ return new_node, partition_info
43
+
44
+ # Fallback
45
+ return _lower_ir_fallback(ir, rec, msg="Sort does not support multiple partitions.")