cudf-polars-cu12 25.2.2__py3-none-any.whl → 25.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/callback.py +85 -53
  3. cudf_polars/containers/column.py +100 -7
  4. cudf_polars/containers/dataframe.py +16 -24
  5. cudf_polars/dsl/expr.py +3 -1
  6. cudf_polars/dsl/expressions/aggregation.py +3 -3
  7. cudf_polars/dsl/expressions/binaryop.py +2 -2
  8. cudf_polars/dsl/expressions/boolean.py +4 -4
  9. cudf_polars/dsl/expressions/datetime.py +39 -1
  10. cudf_polars/dsl/expressions/literal.py +3 -9
  11. cudf_polars/dsl/expressions/selection.py +2 -2
  12. cudf_polars/dsl/expressions/slicing.py +53 -0
  13. cudf_polars/dsl/expressions/sorting.py +1 -1
  14. cudf_polars/dsl/expressions/string.py +4 -4
  15. cudf_polars/dsl/expressions/unary.py +3 -2
  16. cudf_polars/dsl/ir.py +222 -93
  17. cudf_polars/dsl/nodebase.py +8 -1
  18. cudf_polars/dsl/translate.py +66 -38
  19. cudf_polars/experimental/base.py +18 -12
  20. cudf_polars/experimental/dask_serialize.py +22 -8
  21. cudf_polars/experimental/groupby.py +346 -0
  22. cudf_polars/experimental/io.py +13 -11
  23. cudf_polars/experimental/join.py +318 -0
  24. cudf_polars/experimental/parallel.py +57 -6
  25. cudf_polars/experimental/shuffle.py +194 -0
  26. cudf_polars/testing/plugin.py +23 -34
  27. cudf_polars/typing/__init__.py +33 -2
  28. cudf_polars/utils/config.py +138 -0
  29. cudf_polars/utils/conversion.py +40 -0
  30. cudf_polars/utils/dtypes.py +14 -4
  31. cudf_polars/utils/timer.py +39 -0
  32. cudf_polars/utils/versions.py +4 -3
  33. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info}/METADATA +8 -7
  34. cudf_polars_cu12-25.4.0.dist-info/RECORD +55 -0
  35. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info}/WHEEL +1 -1
  36. cudf_polars_cu12-25.2.2.dist-info/RECORD +0 -48
  37. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info/licenses}/LICENSE +0 -0
  38. {cudf_polars_cu12-25.2.2.dist-info → cudf_polars_cu12-25.4.0.dist-info}/top_level.txt +0 -0
cudf_polars/experimental/io.py

@@ -22,14 +22,16 @@ if TYPE_CHECKING:
     from cudf_polars.dsl.expr import NamedExpr
     from cudf_polars.experimental.dispatch import LowerIRTransformer
     from cudf_polars.typing import Schema
+    from cudf_polars.utils.config import ConfigOptions


 @lower_ir_node.register(DataFrameScan)
 def _(
     ir: DataFrameScan, rec: LowerIRTransformer
 ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
-    rows_per_partition = ir.config_options.get("executor_options", {}).get(
-        "max_rows_per_partition", 1_000_000
+    rows_per_partition = ir.config_options.get(
+        "executor_options.max_rows_per_partition",
+        default=1_000_000,
     )

     nrows = max(ir.df.shape()[0], 1)
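The new ConfigOptions accessor takes a dotted path plus a keyword-only default, replacing the chained dict.get calls above. utils/config.py is new in 25.4.0 and its implementation is not shown in this diff, so the following is only a minimal sketch of the lookup behavior the call sites rely on (dotted_get is a hypothetical name, not the package's API):

    from typing import Any

    def dotted_get(options: dict[str, Any], name: str, *, default: Any) -> Any:
        """Walk a nested dict by a dotted path, falling back to `default`."""
        node: Any = options
        for part in name.split("."):
            if not isinstance(node, dict) or part not in node:
                return default
            node = node[part]
        return node

    opts = {"executor_options": {"max_rows_per_partition": 500_000}}
    assert dotted_get(opts, "executor_options.max_rows_per_partition", default=1_000_000) == 500_000
    assert dotted_get(opts, "executor_options.parquet_blocksize", default=1024**3) == 1024**3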
@@ -91,8 +93,10 @@ class ScanPartitionPlan:
         """Extract the partitioning plan of a Scan operation."""
         if ir.typ == "parquet":
             # TODO: Use system info to set default blocksize
-            parallel_options = ir.config_options.get("executor_options", {})
-            blocksize: int = parallel_options.get("parquet_blocksize", 1024**3)
+            blocksize: int = ir.config_options.get(
+                "executor_options.parquet_blocksize",
+                default=1024**3,
+            )
             stats = _sample_pq_statistics(ir)
             file_size = sum(float(stats[column]) for column in ir.schema)
             if file_size > 0:
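The hunk ends before the plan itself is computed, so the exact formula is not visible here, but the inputs show the general shape of blocksize-driven planning. A purely illustrative calculation (not the package's code) under that assumption:

    # Hypothetical: compare the sampled per-file size against the blocksize
    # to decide how coarsely or finely files should map to partitions.
    blocksize = 1024**3              # default: 1 GiB
    file_size = 256 * 1024**2        # suppose sampling suggests ~256 MiB per file
    files_per_partition = max(int(blocksize / file_size), 1)  # -> 4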
@@ -168,7 +172,7 @@ class SplitScan(IR):
         schema: Schema,
         typ: str,
         reader_options: dict[str, Any],
-        config_options: dict[str, Any],
+        config_options: ConfigOptions,
         paths: list[str],
         with_columns: list[str] | None,
         skip_rows: int,
@@ -243,7 +247,7 @@ def _sample_pq_statistics(ir: Scan) -> dict[str, float]:

     # Use average total_uncompressed_size of three files
     # TODO: Use plc.io.parquet_metadata.read_parquet_metadata
-    n_sample = 3
+    n_sample = min(3, len(ir.paths))
     column_sizes = {}
     ds = pa_ds.dataset(random.sample(ir.paths, n_sample), format="parquet")
     for i, frag in enumerate(ds.get_fragments()):
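Clamping n_sample matters because random.sample raises when asked for more items than the population holds, which the old fixed n_sample = 3 would hit on scans over fewer than three files:

    import random

    paths = ["part-0.parquet", "part-1.parquet"]
    try:
        random.sample(paths, 3)
    except ValueError as err:
        print(err)  # Sample larger than population or is negative
    print(random.sample(paths, min(3, len(paths))))  # samples both files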
@@ -270,11 +274,9 @@ def _(
     paths = list(ir.paths)
     if plan.flavor == ScanPartitionFlavor.SPLIT_FILES:
         # Disable chunked reader when splitting files
-        config_options = ir.config_options.copy()
-        config_options["parquet_options"] = config_options.get(
-            "parquet_options", {}
-        ).copy()
-        config_options["parquet_options"]["chunked"] = False
+        config_options = ir.config_options.set(
+            name="parquet_options.chunked", value=False
+        )

     slices: list[SplitScan] = []
     for path in paths:
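For the update path, ConfigOptions.set evidently returns a new options object rather than mutating in place, which is what lets the call site drop the defensive copy()-then-mutate sequence above. A minimal sketch of that copy-on-write behavior (dotted_set is a hypothetical name, not the package's API):

    import copy
    from typing import Any

    def dotted_set(options: dict[str, Any], name: str, value: Any) -> dict[str, Any]:
        """Return an updated deep copy; the input mapping is left untouched."""
        new = copy.deepcopy(options)
        node = new
        *parents, leaf = name.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
        return new

    base = {"parquet_options": {"chunked": True}}
    updated = dotted_set(base, "parquet_options.chunked", False)
    assert base["parquet_options"]["chunked"] is True
    assert updated["parquet_options"]["chunked"] is False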
cudf_polars/experimental/join.py (new file)

@@ -0,0 +1,318 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Parallel Join Logic."""
+
+from __future__ import annotations
+
+import operator
+from functools import reduce
+from typing import TYPE_CHECKING, Any
+
+from cudf_polars.dsl.ir import Join
+from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
+from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
+from cudf_polars.experimental.shuffle import Shuffle, _partition_dataframe
+
+if TYPE_CHECKING:
+    from collections.abc import MutableMapping
+
+    from cudf_polars.dsl.expr import NamedExpr
+    from cudf_polars.dsl.ir import IR
+    from cudf_polars.experimental.parallel import LowerIRTransformer
+    from cudf_polars.utils.config import ConfigOptions
+
+
+def _maybe_shuffle_frame(
+    frame: IR,
+    on: tuple[NamedExpr, ...],
+    partition_info: MutableMapping[IR, PartitionInfo],
+    config_options: ConfigOptions,
+    output_count: int,
+) -> IR:
+    # Shuffle `frame` if it isn't already shuffled.
+    if (
+        partition_info[frame].partitioned_on == on
+        and partition_info[frame].count == output_count
+    ):
+        # Already shuffled
+        return frame
+    else:
+        # Insert new Shuffle node
+        frame = Shuffle(
+            frame.schema,
+            on,
+            config_options,
+            frame,
+        )
+        partition_info[frame] = PartitionInfo(
+            count=output_count,
+            partitioned_on=on,
+        )
+        return frame
+
+
+def _make_hash_join(
+    ir: Join,
+    output_count: int,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    left: IR,
+    right: IR,
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    # Shuffle left and right dataframes (if necessary)
+    new_left = _maybe_shuffle_frame(
+        left,
+        ir.left_on,
+        partition_info,
+        ir.config_options,
+        output_count,
+    )
+    new_right = _maybe_shuffle_frame(
+        right,
+        ir.right_on,
+        partition_info,
+        ir.config_options,
+        output_count,
+    )
+    if left != new_left or right != new_right:
+        ir = ir.reconstruct([new_left, new_right])
+        left = new_left
+        right = new_right
+
+    # Record new partitioning info
+    partitioned_on: tuple[NamedExpr, ...] = ()
+    if ir.left_on == ir.right_on or (ir.options[0] in ("Left", "Semi", "Anti")):
+        partitioned_on = ir.left_on
+    elif ir.options[0] == "Right":
+        partitioned_on = ir.right_on
+    partition_info[ir] = PartitionInfo(
+        count=output_count,
+        partitioned_on=partitioned_on,
+    )
+
+    return ir, partition_info
+
+
+def _should_bcast_join(
+    ir: Join,
+    left: IR,
+    right: IR,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    output_count: int,
+) -> bool:
+    # Decide if a broadcast join is appropriate.
+    if partition_info[left].count >= partition_info[right].count:
+        small_count = partition_info[right].count
+        large = left
+        large_on = ir.left_on
+    else:
+        small_count = partition_info[left].count
+        large = right
+        large_on = ir.right_on
+
+    # Avoid the broadcast if the "large" table is already shuffled
+    large_shuffled = (
+        partition_info[large].partitioned_on == large_on
+        and partition_info[large].count == output_count
+    )
+
+    # Broadcast-join criteria:
+    # 1. The large dataframe isn't already shuffled.
+    # 2. The small dataframe has no more partitions than
+    #    "executor_options.broadcast_join_limit" (16 by default).
+    #    TODO: Tune this heuristic; we may want to account
+    #    for the number of workers.
+    # 3. The "kind" of join is compatible with a broadcast join.
+    return (
+        not large_shuffled
+        and small_count
+        <= ir.config_options.get(
+            # Maximum number of "small"-table partitions to bcast
+            "executor_options.broadcast_join_limit",
+            default=16,
+        )
+        and (
+            ir.options[0] == "Inner"
+            or (ir.options[0] in ("Left", "Semi", "Anti") and large == left)
+            or (ir.options[0] == "Right" and large == right)
+        )
+    )
+
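A toy walk-through of the three criteria with hypothetical partition counts (all values below are illustrative, except the limit of 16, which is the default in the code above):

    small_count = 4            # partitions on the smaller side
    broadcast_join_limit = 16  # executor_options.broadcast_join_limit default
    large_shuffled = False     # large side not already partitioned on its keys
    kind = "Inner"

    use_broadcast = (
        not large_shuffled
        and small_count <= broadcast_join_limit
        and kind == "Inner"  # or Left/Semi/Anti when large is left, Right when large is right
    )
    assert use_broadcast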
+
+def _make_bcast_join(
+    ir: Join,
+    output_count: int,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    left: IR,
+    right: IR,
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    if ir.options[0] != "Inner":
+        left_count = partition_info[left].count
+        right_count = partition_info[right].count
+
+        # Shuffle the smaller table (if necessary) - Notes:
+        # - We need to shuffle the smaller table if
+        #   (1) we are not doing an "inner" join,
+        #   and (2) the small table contains multiple
+        #   partitions.
+        # - We cannot simply join a large-table partition
+        #   to each small-table partition, and then
+        #   concatenate the partial-join results, because
+        #   a non-"inner" join does NOT commute with
+        #   concatenation.
+        # - In some cases, we can perform the partial joins
+        #   sequentially. However, we are starting with a
+        #   catch-all algorithm that works for all cases.
+        if left_count >= right_count:
+            right = _maybe_shuffle_frame(
+                right,
+                ir.right_on,
+                partition_info,
+                ir.config_options,
+                right_count,
+            )
+        else:
+            left = _maybe_shuffle_frame(
+                left,
+                ir.left_on,
+                partition_info,
+                ir.config_options,
+                left_count,
+            )
+
+    new_node = ir.reconstruct([left, right])
+    partition_info[new_node] = PartitionInfo(count=output_count)
+    return new_node, partition_info
+
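The note that a non-"inner" join does not commute with concatenation is easy to see with a toy example (standalone CPU polars, independent of this package):

    import polars as pl

    left = pl.DataFrame({"k": [1, 2]})
    s0 = pl.DataFrame({"k": [1], "v": ["a"]})
    s1 = pl.DataFrame({"k": [2], "v": ["b"]})

    # Left-joining against each small piece and concatenating duplicates
    # rows: k=1 matches in s0 but is null-padded by the join with s1
    # (and vice versa for k=2), giving 4 rows instead of 2.
    pieces = [left.join(s, on="k", how="left") for s in (s0, s1)]
    print(pl.concat(pieces))

    # The correct left join has exactly one row per left-side key:
    print(left.join(pl.concat([s0, s1]), on="k", how="left"))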
+
+@lower_ir_node.register(Join)
+def _(
+    ir: Join, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    # Lower children
+    children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
+    partition_info = reduce(operator.or_, _partition_info)
+
+    left, right = children
+    output_count = max(partition_info[left].count, partition_info[right].count)
+    if output_count == 1:
+        new_node = ir.reconstruct(children)
+        partition_info[new_node] = PartitionInfo(count=1)
+        return new_node, partition_info
+    elif ir.options[0] == "Cross":
+        raise NotImplementedError(
+            "Cross join not supported for multiple partitions."
+        )  # pragma: no cover
+
+    if _should_bcast_join(ir, left, right, partition_info, output_count):
+        # Create a broadcast join
+        return _make_bcast_join(
+            ir,
+            output_count,
+            partition_info,
+            left,
+            right,
+        )
+    else:
+        # Create a hash join
+        return _make_hash_join(
+            ir,
+            output_count,
+            partition_info,
+            left,
+            right,
+        )
+
+
+@generate_ir_tasks.register(Join)
+def _(
+    ir: Join, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+    left, right = ir.children
+    output_count = partition_info[ir].count
+
+    left_partitioned = (
+        partition_info[left].partitioned_on == ir.left_on
+        and partition_info[left].count == output_count
+    )
+    right_partitioned = (
+        partition_info[right].partitioned_on == ir.right_on
+        and partition_info[right].count == output_count
+    )
+
+    if output_count == 1 or (left_partitioned and right_partitioned):
+        # Partition-wise join
+        left_name = get_key_name(left)
+        right_name = get_key_name(right)
+        return {
+            key: (
+                ir.do_evaluate,
+                *ir._non_child_args,
+                (left_name, i),
+                (right_name, i),
+            )
+            for i, key in enumerate(partition_info[ir].keys(ir))
+        }
+    else:
+        # Broadcast join
+        left_parts = partition_info[left]
+        right_parts = partition_info[right]
+        if left_parts.count >= right_parts.count:
+            small_side = "Right"
+            small_name = get_key_name(right)
+            small_size = partition_info[right].count
+            large_name = get_key_name(left)
+            large_on = ir.left_on
+        else:
+            small_side = "Left"
+            small_name = get_key_name(left)
+            small_size = partition_info[left].count
+            large_name = get_key_name(right)
+            large_on = ir.right_on
+
+        graph: MutableMapping[Any, Any] = {}
+
+        out_name = get_key_name(ir)
+        out_size = partition_info[ir].count
+        split_name = f"split-{out_name}"
+        inter_name = f"inter-{out_name}"
+
+        for part_out in range(out_size):
+            if ir.options[0] != "Inner":
+                graph[(split_name, part_out)] = (
+                    _partition_dataframe,
+                    (large_name, part_out),
+                    large_on,
+                    small_size,
+                )
+
+            _concat_list = []
+            for j in range(small_size):
+                join_children = [
+                    (
+                        (
+                            operator.getitem,
+                            (split_name, part_out),
+                            j,
+                        )
+                        if ir.options[0] != "Inner"
+                        else (large_name, part_out)
+                    ),
+                    (small_name, j),
+                ]
+                if small_side == "Left":
+                    join_children.reverse()
+
+                inter_key = (inter_name, part_out, j)
+                graph[inter_key] = (
+                    ir.do_evaluate,
+                    ir.left_on,
+                    ir.right_on,
+                    ir.options,
+                    *join_children,
+                )
+                _concat_list.append(inter_key)
+            if len(_concat_list) == 1:
+                graph[(out_name, part_out)] = graph.pop(_concat_list[0])
+            else:
+                graph[(out_name, part_out)] = (_concat, *_concat_list)
+
+        return graph
cudf_polars/experimental/parallel.py

@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Multi-partition Dask execution."""

@@ -7,10 +7,13 @@ from __future__ import annotations
 import itertools
 import operator
 from functools import reduce
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar

+import cudf_polars.experimental.groupby
 import cudf_polars.experimental.io
-import cudf_polars.experimental.select  # noqa: F401
+import cudf_polars.experimental.join
+import cudf_polars.experimental.select
+import cudf_polars.experimental.shuffle  # noqa: F401
 from cudf_polars.dsl.ir import IR, Cache, Filter, HStack, Projection, Select, Union
 from cudf_polars.dsl.traversal import CachingVisitor, traversal
 from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
@@ -22,10 +25,38 @@ from cudf_polars.experimental.dispatch import (
 if TYPE_CHECKING:
     from collections.abc import MutableMapping

+    from distributed import Client
+
     from cudf_polars.containers import DataFrame
     from cudf_polars.experimental.dispatch import LowerIRTransformer


+class SerializerManager:
+    """Manager to ensure the serializer is only registered once."""
+
+    _serializer_registered: bool = False
+    _client_run_executed: ClassVar[set[str]] = set()
+
+    @classmethod
+    def register_serialize(cls) -> None:
+        """Register Dask/cudf-polars serializers in the calling process."""
+        if not cls._serializer_registered:
+            from cudf_polars.experimental.dask_serialize import register
+
+            register()
+            cls._serializer_registered = True
+
+    @classmethod
+    def run_on_cluster(cls, client: Client) -> None:
+        """Run serializer registration on the workers and scheduler."""
+        if (
+            client.id not in cls._client_run_executed
+        ):  # pragma: no cover; only executes with Distributed scheduler
+            client.run(cls.register_serialize)
+            client.run_on_scheduler(cls.register_serialize)
+            cls._client_run_executed.add(client.id)
+
+
 @lower_ir_node.register(IR)
 def _(ir: IR, rec: LowerIRTransformer) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
     # Default logic - Requires single partition
@@ -119,18 +150,38 @@ def task_graph(
     key_name = get_key_name(ir)
     partition_count = partition_info[ir].count
     if partition_count > 1:
-        graph[key_name] = (_concat, list(partition_info[ir].keys(ir)))
+        graph[key_name] = (_concat, *partition_info[ir].keys(ir))
         return graph, key_name
     else:
         return graph, (key_name, 0)


+def get_client():
+    """Get appropriate Dask client or scheduler."""
+    SerializerManager.register_serialize()
+
+    try:  # pragma: no cover; block depends on executor type and Distributed cluster
+        from distributed import get_client
+
+        client = get_client()
+        SerializerManager.run_on_cluster(client)
+    except (
+        ImportError,
+        ValueError,
+    ):  # pragma: no cover; block depends on Dask local scheduler
+        from dask import get
+
+        return get
+    else:  # pragma: no cover; block depends on executor type and Distributed cluster
+        return client.get
+
+
 def evaluate_dask(ir: IR) -> DataFrame:
     """Evaluate an IR graph with Dask."""
-    from dask import get
-
     ir, partition_info = lower_ir_graph(ir)

+    get = get_client()
+
     graph, key = task_graph(ir, partition_info)
     return get(graph, key)

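evaluate_dask hands the finished mapping to a Dask scheduler's get(graph, key) entry point, so the graph format is Dask's plain dict convention: keys name tasks, and a tuple value means "call the first element on the (recursively resolved) rest". A self-contained illustration using the synchronous scheduler that the local-path fallback above imports:

    from dask import get  # synchronous scheduler, as in the local-path fallback

    def add(a, b):
        return a + b

    graph = {
        ("x", 0): 1,
        ("x", 1): 2,
        ("sum", 0): (add, ("x", 0), ("x", 1)),
    }
    print(get(graph, ("sum", 0)))  # -> 3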
cudf_polars/experimental/shuffle.py (new file)

@@ -0,0 +1,194 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Shuffle Logic."""
+
+from __future__ import annotations
+
+import operator
+from typing import TYPE_CHECKING, Any
+
+import pyarrow as pa
+
+import pylibcudf as plc
+
+from cudf_polars.containers import DataFrame
+from cudf_polars.dsl.ir import IR
+from cudf_polars.experimental.base import _concat, get_key_name
+from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
+
+if TYPE_CHECKING:
+    from collections.abc import MutableMapping
+
+    from cudf_polars.dsl.expr import NamedExpr
+    from cudf_polars.experimental.dispatch import LowerIRTransformer
+    from cudf_polars.experimental.parallel import PartitionInfo
+    from cudf_polars.typing import Schema
+    from cudf_polars.utils.config import ConfigOptions
+
+
+class Shuffle(IR):
+    """
+    Shuffle multi-partition data.
+
+    Notes
+    -----
+    Only hash-based partitioning is supported (for now).
+    """
+
+    __slots__ = ("config_options", "keys")
+    _non_child = ("schema", "keys", "config_options")
+    keys: tuple[NamedExpr, ...]
+    """Keys to shuffle on."""
+    config_options: ConfigOptions
+    """Configuration options."""
+
+    def __init__(
+        self,
+        schema: Schema,
+        keys: tuple[NamedExpr, ...],
+        config_options: ConfigOptions,
+        df: IR,
+    ):
+        self.schema = schema
+        self.keys = keys
+        self.config_options = config_options
+        self._non_child_args = (schema, keys, config_options)
+        self.children = (df,)
+
+    @classmethod
+    def do_evaluate(
+        cls,
+        schema: Schema,
+        keys: tuple[NamedExpr, ...],
+        config_options: ConfigOptions,
+        df: DataFrame,
+    ):  # pragma: no cover
+        """Evaluate and return a dataframe."""
+        # Single-partition Shuffle evaluation is a no-op
+        return df
+
+
+def _partition_dataframe(
+    df: DataFrame,
+    keys: tuple[NamedExpr, ...],
+    count: int,
+) -> dict[int, DataFrame]:
+    """
+    Partition an input DataFrame for shuffling.
+
+    Notes
+    -----
+    This utility only supports hash partitioning (for now).
+
+    Parameters
+    ----------
+    df
+        DataFrame to partition.
+    keys
+        Shuffle key(s).
+    count
+        Total number of output partitions.
+
+    Returns
+    -------
+    A dictionary mapping between int partition indices and
+    DataFrame fragments.
+    """
+    # Hash the specified keys to calculate the output
+    # partition for each row
+    partition_map = plc.binaryop.binary_operation(
+        plc.hashing.murmurhash3_x86_32(
+            DataFrame([expr.evaluate(df) for expr in keys]).table
+        ),
+        plc.interop.from_arrow(pa.scalar(count, type="uint32")),
+        plc.binaryop.BinaryOperator.PYMOD,
+        plc.types.DataType(plc.types.TypeId.UINT32),
+    )
+
+    # Apply partitioning
+    t, offsets = plc.partitioning.partition(
+        df.table,
+        partition_map,
+        count,
+    )
+
+    # Split and return the partitioned result
+    return {
+        i: DataFrame.from_table(
+            split,
+            df.column_names,
+        )
+        for i, split in enumerate(plc.copying.split(t, offsets[1:-1]))
+    }
+
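The key property of this hashing step is that rows with equal key values always land in the same output partition, so matching join/group keys are co-located after a shuffle. A pure-Python analogue of the bucketing invariant (the real code uses murmurhash3 on the GPU, not least because Python's built-in str hashing is salted per process and therefore unusable across workers):

    count = 4
    rows = [("a", 1), ("b", 2), ("a", 3)]
    parts: dict[int, list[tuple[str, int]]] = {i: [] for i in range(count)}
    for key, val in rows:
        parts[hash(key) % count].append((key, val))
    # Both "a" rows are guaranteed to sit in the same bucket.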
+
+def _simple_shuffle_graph(
+    name_in: str,
+    name_out: str,
+    keys: tuple[NamedExpr, ...],
+    count_in: int,
+    count_out: int,
+) -> MutableMapping[Any, Any]:
+    """Make a simple all-to-all shuffle graph."""
+    split_name = f"split-{name_out}"
+    inter_name = f"inter-{name_out}"
+
+    graph: MutableMapping[Any, Any] = {}
+    for part_out in range(count_out):
+        _concat_list = []
+        for part_in in range(count_in):
+            graph[(split_name, part_in)] = (
+                _partition_dataframe,
+                (name_in, part_in),
+                keys,
+                count_out,
+            )
+            _concat_list.append((inter_name, part_out, part_in))
+            graph[_concat_list[-1]] = (
+                operator.getitem,
+                (split_name, part_in),
+                part_out,
+            )
+        graph[(name_out, part_out)] = (_concat, *_concat_list)
+    return graph
+
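Each input partition is split once into count_out pieces, and every output task plucks its piece with operator.getitem before concatenating. The same split-then-index pattern in miniature, runnable with Dask's synchronous scheduler (the split function here is a toy stand-in for _partition_dataframe):

    import operator

    from dask import get

    def split(values, count):
        out = {i: [] for i in range(count)}
        for v in values:
            out[v % count].append(v)
        return out

    graph = {
        ("split", 0): (split, [1, 2, 3, 4], 2),
        ("out", 0): (operator.getitem, ("split", 0), 0),
        ("out", 1): (operator.getitem, ("split", 0), 1),
    }
    print(get(graph, ("out", 0)))  # -> [2, 4]
    print(get(graph, ("out", 1)))  # -> [1, 3]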
+
+@lower_ir_node.register(Shuffle)
+def _(
+    ir: Shuffle, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    # Simple lower_ir_node handling for the default hash-based shuffle.
+    # More-complex logic (e.g. joining and sorting) should
+    # be handled separately.
+    from cudf_polars.experimental.parallel import PartitionInfo
+
+    (child,) = ir.children
+
+    new_child, pi = rec(child)
+    if pi[new_child].count == 1 or ir.keys == pi[new_child].partitioned_on:
+        # Already shuffled
+        return new_child, pi
+    new_node = ir.reconstruct([new_child])
+    pi[new_node] = PartitionInfo(
+        # Default shuffle preserves partition count
+        count=pi[new_child].count,
+        # Add partitioned_on info
+        partitioned_on=ir.keys,
+    )
+    return new_node, pi
+
+
+@generate_ir_tasks.register(Shuffle)
+def _(
+    ir: Shuffle, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+    # Use a simple all-to-all shuffle graph.
+
+    # TODO: Optionally use rapidsmp.
+    return _simple_shuffle_graph(
+        get_key_name(ir.children[0]),
+        get_key_name(ir),
+        ir.keys,
+        partition_info[ir.children[0]].count,
+        partition_info[ir].count,
+    )