cudf-polars-cu13 25.10.0-py3-none-any.whl → 26.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/experimental/repartition.py

@@ -5,6 +5,7 @@
 from __future__ import annotations

 import itertools
+from functools import partial
 from typing import TYPE_CHECKING, Any

 from cudf_polars.dsl.ir import IR
@@ -15,6 +16,7 @@ from cudf_polars.experimental.utils import _concat
 if TYPE_CHECKING:
     from collections.abc import MutableMapping

+    from cudf_polars.dsl.ir import IRExecutionContext
     from cudf_polars.experimental.parallel import PartitionInfo
     from cudf_polars.typing import Schema

@@ -44,7 +46,9 @@ class Repartition(IR):

 @generate_ir_tasks.register(Repartition)
 def _(
-    ir: Repartition, partition_info: MutableMapping[IR, PartitionInfo]
+    ir: Repartition,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     # Repartition an IR node.
     # Only supports rapartitioning to fewer (for now).
@@ -64,6 +68,9 @@ def _(
     offsets = [0, *itertools.accumulate(n + (i < remainder) for i in range(count_out))]
     child_keys = tuple(partition_info[child].keys(child))
     return {
-        (key_name, i): (_concat, *child_keys[offsets[i] : offsets[i + 1]])
+        (key_name, i): (
+            partial(_concat, context=context),
+            *child_keys[offsets[i] : offsets[i + 1]],
+        )
         for i in range(count_out)
     }
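In the hunk above, each repartition task now carries the new IRExecutionContext by partially applying _concat, while the child partition keys stay positional so the task tuples keep the usual (callable, *keys) shape. A minimal sketch of that convention, assuming a toy Dask-style graph; the _concat below is a stand-in and context is a plain dict, not the cudf-polars types:

from functools import partial


def _concat(*frames, context):
    # Stand-in for cudf_polars.experimental.utils._concat; the real helper
    # uses the execution context for stream/resource handling.
    return [row for frame in frames for row in frame]


context = {"stream": "dask-cuda-stream"}  # hypothetical execution context
child_keys = [("child", 0), ("child", 1), ("child", 2)]

# A task is (callable, *child_keys); the scheduler resolves each key to the
# materialized partition and passes it positionally.
graph = {
    ("repartition", 0): (partial(_concat, context=context), *child_keys[:2]),
    ("repartition", 1): (partial(_concat, context=context), *child_keys[2:]),
}

# Toy "scheduler": resolve dependencies and execute each task.
store = {("child", 0): [1, 2], ("child", 1): [3], ("child", 2): [4, 5]}
for key, (func, *deps) in graph.items():
    store[key] = func(*(store[d] for d in deps))

print(store[("repartition", 0)], store[("repartition", 1)])  # [1, 2, 3] [4, 5]

Baking context in with partial leaves the graph format the scheduler consumes unchanged; only the callable differs.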
cudf_polars/experimental/select.py

@@ -4,6 +4,7 @@

 from __future__ import annotations

+from collections import defaultdict
 from typing import TYPE_CHECKING

 import polars as pl
@@ -12,20 +13,23 @@ from cudf_polars.dsl import expr
 from cudf_polars.dsl.expr import Col, Len
 from cudf_polars.dsl.ir import Empty, HConcat, Scan, Select, Union
 from cudf_polars.dsl.traversal import traversal
+from cudf_polars.dsl.utils.naming import unique_names
 from cudf_polars.experimental.base import ColumnStat, PartitionInfo
 from cudf_polars.experimental.dispatch import lower_ir_node
 from cudf_polars.experimental.expressions import decompose_expr_graph
+from cudf_polars.experimental.repartition import Repartition
 from cudf_polars.experimental.utils import (
     _contains_unsupported_fill_strategy,
     _lower_ir_fallback,
 )

 if TYPE_CHECKING:
-    from collections.abc import MutableMapping
+    from collections.abc import MutableMapping, Sequence

     from cudf_polars.dsl.ir import IR
     from cudf_polars.experimental.parallel import LowerIRTransformer
     from cudf_polars.experimental.statistics import StatsCollector
+    from cudf_polars.typing import Schema
     from cudf_polars.utils.config import ConfigOptions


@@ -74,7 +78,10 @@ def decompose_select(
     decompose_expr_graph
     """
     # Collect partial selections
-    selections = []
+    selections: list[Select] = []
+    name_generator = unique_names(
+        (*(ne.name for ne in select_ir.exprs), *input_ir.schema.keys())
+    )
     for ne in select_ir.exprs:
         # Decompose this partial expression
         new_ne, partial_input_ir, _partition_info = decompose_expr_graph(
@@ -84,6 +91,7 @@ def decompose_select(
             config_options,
             stats.row_count.get(select_ir.children[0], ColumnStat[int](None)),
             stats.column_stats.get(select_ir.children[0], {}),
+            name_generator,
         )
         pi = _partition_info[partial_input_ir]
         partial_input_ir = Select(
@@ -97,7 +105,11 @@ def decompose_select(
         selections.append(partial_input_ir)

     # Concatenate partial selections
-    new_ir: HConcat | Select
+    new_ir: Select | HConcat
+    selections, partition_info = _fuse_simple_reductions(
+        selections,
+        partition_info,
+    )
     if len(selections) > 1:
         new_ir = HConcat(
             select_ir.schema,
@@ -113,6 +125,151 @@ def decompose_select(
     return new_ir, partition_info


+def _fuse_simple_reductions(
+    decomposed_select_irs: Sequence[Select],
+    pi: MutableMapping[IR, PartitionInfo],
+) -> tuple[list[Select], MutableMapping[IR, PartitionInfo]]:
+    """
+    Fuse simple reductions that are part of the same Select node.
+
+    Parameters
+    ----------
+    decomposed_select_irs
+        The decomposed Select nodes.
+    pi
+        Partition information.
+
+    Returns
+    -------
+    fused_select_irs, pi
+        The new Select nodes, and the updated partition information.
+    """
+    # After a Select node is decomposed, it will be broken into
+    # one or more Select nodes that each target a different
+    # named expression. In some cases, one or more of these
+    # decomposed select nodes will be simple reductions that
+    # *should* be performed at the same time. Each "simple"
+    # reduction will have the following pattern:
+    #
+    #     # Partition-wise column selection (select_c)
+    #     Select(
+    #         # Outer Agg selection (select_b)
+    #         Select(
+    #             # Repartition to 1 (repartition)
+    #             Repartition(
+    #                 # Inner Agg selection (select_a)
+    #                 Select(
+    #                     ...
+    #                 )
+    #             )
+    #         )
+    #     )
+    #
+    # We need to fuse these simple reductions together to
+    # avoid unnecessary memory pressure.
+
+    # If there is only one decomposed_select_ir, return it
+    if len(decomposed_select_irs) == 1:
+        return list(decomposed_select_irs), pi
+
+    fused_select_c_exprs = []
+    fused_select_c_schema: Schema = {}
+
+    # Find reduction groups
+    reduction_groups: defaultdict[IR, list[Select]] = defaultdict(list)
+    for select_c in decomposed_select_irs:
+        # Final expressions and schema must be included in
+        # the fused select_c node even if this specific
+        # selection is not a simple reduction.
+        fused_select_c_exprs.extend(list(select_c.exprs))
+        fused_select_c_schema |= select_c.schema
+
+        if (
+            isinstance((select_b := select_c.children[0]), Select)
+            and pi[select_b].count == 1
+            and isinstance(repartition := select_b.children[0], Repartition)
+            and pi[repartition].count == 1
+            and isinstance(select_a := repartition.children[0], Select)
+        ):
+            # We have a simple reduction that may be
+            # fused with other simple reductions
+            # sharing the same root.
+            reduction_root = select_a.children[0]
+            reduction_groups[reduction_root].append(select_c)
+        else:
+            # Not a simple reduction.
+            # This selection becomes it own "group".
+            reduction_groups[select_c].append(select_c)
+
+    new_decomposed_select_irs: list[IR] = []
+    for root_ir, group in reduction_groups.items():
+        if len(group) > 1:
+            # Fuse simple-aggregation group
+            fused_select_b_exprs = []
+            fused_select_a_exprs = []
+            fused_select_b_schema: Schema = {}
+            fused_select_a_schema: Schema = {}
+            for select_c in group:
+                select_b = select_c.children[0]
+                assert isinstance(select_b, Select), (
+                    f"Expected Select, got {type(select_b)}"
+                )
+                fused_select_b_exprs.extend(list(select_b.exprs))
+                fused_select_b_schema |= select_b.schema
+                select_a = select_b.children[0].children[0]
+                assert isinstance(select_a, Select), (
+                    f"Expected Select, got {type(select_a)}"
+                )
+                fused_select_a_exprs.extend(list(select_a.exprs))
+                fused_select_a_schema |= select_a.schema
+            fused_select_a = Select(
+                fused_select_a_schema,
+                fused_select_a_exprs,
+                True,  # noqa: FBT003
+                root_ir,
+            )
+            pi[fused_select_a] = PartitionInfo(count=pi[root_ir].count)
+            fused_repartition = Repartition(fused_select_a_schema, fused_select_a)
+            pi[fused_repartition] = PartitionInfo(count=1)
+            fused_select_b = Select(
+                fused_select_b_schema,
+                fused_select_b_exprs,
+                True,  # noqa: FBT003
+                fused_repartition,
+            )
+            pi[fused_select_b] = PartitionInfo(count=1)
+            new_decomposed_select_irs.append(fused_select_b)
+        else:
+            # Nothing to fuse for this group
+            new_decomposed_select_irs.append(group[0])
+
+    # If any aggregations were fused, we must concatenate
+    # the results and apply the final (fused) "c" selection,
+    # otherwise we may mess up the ordering of the columns.
+    if len(new_decomposed_select_irs) < len(decomposed_select_irs):
+        # Compute schema from actual children (intermediate columns)
+        hconcat_schema: Schema = {}
+        for ir in new_decomposed_select_irs:
+            hconcat_schema |= ir.schema
+        new_hconcat = HConcat(
+            hconcat_schema,
+            True,  # noqa: FBT003
+            *new_decomposed_select_irs,
+        )
+        count = max(pi[c].count for c in new_decomposed_select_irs)
+        pi[new_hconcat] = PartitionInfo(count=count)
+        fused_select_c = Select(
+            fused_select_c_schema,
+            fused_select_c_exprs,
+            True,  # noqa: FBT003
+            new_hconcat,
+        )
+        pi[fused_select_c] = PartitionInfo(count=count)
+        return [fused_select_c], pi
+
+    return list(decomposed_select_irs), pi
+
+
 @lower_ir_node.register(Select)
 def _(
     ir: Select, rec: LowerIRTransformer
@@ -130,21 +287,27 @@
                 "for multiple partitions; falling back to in-memory evaluation."
             ),
         )
-    if (
-        pi.count == 1
-        and Select._is_len_expr(ir.exprs)
-        and isinstance(child, Union)
-        and len(child.children) == 1
-        and isinstance(child.children[0], Scan)
-        and child.children[0].predicate is None
-    ):
+
+    scan_child: Scan | None = None
+    if pi.count == 1 and Select._is_len_expr(ir.exprs):
+        if (
+            isinstance(child, Union)
+            and len(child.children) == 1
+            and isinstance(child.children[0], Scan)
+        ):
+            # Task engine case
+            scan_child = child.children[0]
+        elif isinstance(child, Scan):  # pragma: no cover; Requires rapidsmpf runtime
+            # RapidsMPF case
+            scan_child = child
+
+    if scan_child and scan_child.predicate is None:
         # Special Case: Fast count.
-        scan = child.children[0]
-        count = scan.fast_count()
+        count = scan_child.fast_count()
         dtype = ir.exprs[0].value.dtype

         lit_expr = expr.LiteralColumn(
-            dtype, pl.Series(values=[count], dtype=dtype.polars)
+            dtype, pl.Series(values=[count], dtype=dtype.polars_type)
         )
         named_expr = expr.NamedExpr(ir.exprs[0].name or "len", lit_expr)

cudf_polars/experimental/shuffle.py

@@ -5,6 +5,7 @@
 from __future__ import annotations

 import operator
+from functools import partial
 from typing import TYPE_CHECKING, Any, Concatenate, Literal, TypeVar, TypedDict

 import pylibcudf as plc
@@ -13,16 +14,19 @@ from rmm.pylibrmm.stream import DEFAULT_STREAM
 from cudf_polars.containers import DataFrame
 from cudf_polars.dsl.expr import Col
 from cudf_polars.dsl.ir import IR
-from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
+from cudf_polars.dsl.tracing import log_do_evaluate, nvtx_annotate_cudf_polars
 from cudf_polars.experimental.base import get_key_name
 from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
 from cudf_polars.experimental.utils import _concat
+from cudf_polars.utils.config import ShufflerInsertionMethod
+from cudf_polars.utils.cuda_stream import get_dask_cuda_stream

 if TYPE_CHECKING:
     from collections.abc import Callable, MutableMapping, Sequence

     from cudf_polars.containers import DataType
     from cudf_polars.dsl.expr import NamedExpr
+    from cudf_polars.dsl.ir import IRExecutionContext
     from cudf_polars.experimental.dispatch import LowerIRTransformer
     from cudf_polars.experimental.parallel import PartitionInfo
     from cudf_polars.typing import Schema
@@ -40,6 +44,7 @@ class ShuffleOptions(TypedDict):
     column_names: Sequence[str]
     dtypes: Sequence[DataType]
     cluster_kind: Literal["dask", "single"]
+    shuffler_insertion_method: ShufflerInsertionMethod


 # Experimental rapidsmpf shuffler integration
@@ -77,7 +82,14 @@ class RMPFIntegration: # pragma: no cover
             br=context.br,
             stream=DEFAULT_STREAM,
         )
-        shuffler.insert_chunks(packed_inputs)
+
+        if (
+            options["shuffler_insertion_method"]
+            == ShufflerInsertionMethod.CONCAT_INSERT
+        ):
+            shuffler.concat_insert(packed_inputs)
+        else:
+            shuffler.insert_chunks(packed_inputs)

     @staticmethod
     @nvtx_annotate_cudf_polars(message="RMPFIntegration.extract_partition")
@@ -116,6 +128,7 @@ class RMPFIntegration: # pragma: no cover
             ),
             column_names,
             dtypes,
+            get_dask_cuda_stream(),
         )


@@ -129,33 +142,44 @@ class Shuffle(IR):
     `ShuffleSorted` for sorting-based shuffling.
     """

-    __slots__ = ("keys", "shuffle_method")
-    _non_child = ("schema", "keys", "shuffle_method")
+    __slots__ = ("keys", "shuffle_method", "shuffler_insertion_method")
+    _non_child = ("schema", "keys", "shuffle_method", "shuffler_insertion_method")
     keys: tuple[NamedExpr, ...]
     """Keys to shuffle on."""
     shuffle_method: ShuffleMethod
     """Shuffle method to use."""
+    shuffler_insertion_method: ShufflerInsertionMethod
+    """Insertion method for rapidsmpf shuffler."""

     def __init__(
         self,
         schema: Schema,
         keys: tuple[NamedExpr, ...],
         shuffle_method: ShuffleMethod,
+        shuffler_insertion_method: ShufflerInsertionMethod,
         df: IR,
     ):
         self.schema = schema
         self.keys = keys
         self.shuffle_method = shuffle_method
-        self._non_child_args = (schema, keys, shuffle_method)
+        self.shuffler_insertion_method = shuffler_insertion_method
+        self._non_child_args = (schema, keys, shuffle_method, shuffler_insertion_method)
         self.children = (df,)

-    @classmethod
+    # the type-ignore is for
+    # Argument 1 to "log_do_evaluate" has incompatible type "Callable[[type[Shuffle], <snip>]"
+    # expected Callable[[type[IR], <snip>]
+    # But Shuffle is a subclass of IR, so this is fine.
+    @classmethod  # type: ignore[arg-type]
+    @log_do_evaluate
     def do_evaluate(
         cls,
         schema: Schema,
         keys: tuple[NamedExpr, ...],
         shuffle_method: ShuffleMethod,
         df: DataFrame,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:  # pragma: no cover
         """Evaluate and return a dataframe."""
         # Single-partition Shuffle evaluation is a no-op
@@ -201,11 +225,15 @@ def _hash_partition_dataframe(
     # partition for each row
     partition_map = plc.binaryop.binary_operation(
         plc.hashing.murmurhash3_x86_32(
-            DataFrame([expr.evaluate(df) for expr in on]).table
+            DataFrame([expr.evaluate(df) for expr in on], stream=df.stream).table,
+            stream=df.stream,
+        ),
+        plc.Scalar.from_py(
+            partition_count, plc.DataType(plc.TypeId.UINT32), stream=df.stream
         ),
-        plc.Scalar.from_py(partition_count, plc.DataType(plc.TypeId.UINT32)),
         plc.binaryop.BinaryOperator.PYMOD,
         plc.types.DataType(plc.types.TypeId.UINT32),
+        stream=df.stream,
     )

     # Apply partitioning
@@ -213,6 +241,7 @@ def _hash_partition_dataframe(
         df.table,
         partition_map,
         partition_count,
+        stream=df.stream,
    )
     splits = offsets[1:-1]

@@ -222,8 +251,9 @@ def _hash_partition_dataframe(
             split,
             df.column_names,
             df.dtypes,
+            df.stream,
         )
-        for i, split in enumerate(plc.copying.split(t, splits))
+        for i, split in enumerate(plc.copying.split(t, splits, stream=df.stream))
     }


@@ -242,6 +272,7 @@ def _simple_shuffle_graph(
     ],
     options: OPT_T,
     *other: Any,
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     """Make a simple all-to-all shuffle graph."""
     split_name = f"split-{name_out}"
@@ -265,7 +296,7 @@ def _simple_shuffle_graph(
                 (split_name, part_in),
                 part_out,
             )
-        graph[(name_out, part_out)] = (_concat, *_concat_list)
+        graph[(name_out, part_out)] = (partial(_concat, context=context), *_concat_list)
     return graph


@@ -296,7 +327,9 @@ def _(

 @generate_ir_tasks.register(Shuffle)
 def _(
-    ir: Shuffle, partition_info: MutableMapping[IR, PartitionInfo]
+    ir: Shuffle,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     # Extract "shuffle_method" configuration
     shuffle_method = ir.shuffle_method
@@ -331,6 +364,7 @@ def _(
                 "column_names": list(ir.schema.keys()),
                 "dtypes": list(ir.schema.values()),
                 "cluster_kind": cluster_kind,
+                "shuffler_insertion_method": ir.shuffler_insertion_method,
             },
         )
     except ValueError as err:
@@ -343,7 +377,7 @@ def _(
         ) from err

     # Simple task-based fall-back
-    return _simple_shuffle_graph(
+    return partial(_simple_shuffle_graph, context=context)(
         get_key_name(ir.children[0]),
         get_key_name(ir),
         partition_info[ir.children[0]].count,
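For orientation on the two helpers touched above: _hash_partition_dataframe splits one input partition by murmurhash(key) mod partition_count, and _simple_shuffle_graph builds output partition j by concatenating piece j from every input (now through partial(_concat, context=context)). A pure-Python sketch of that structure, with hypothetical helper names and eager execution; the real code does the hashing and splitting on the GPU via pylibcudf and passes df.stream to every call:

from collections import defaultdict


def hash_partition(rows, key, partition_count):
    # Analogue of _hash_partition_dataframe: bucket each row by
    # hash(key) mod partition_count and return one piece per bucket.
    pieces = defaultdict(list)
    for row in rows:
        pieces[hash(row[key]) % partition_count].append(row)
    return pieces


def simple_shuffle(chunks, key, partition_count):
    # Analogue of _simple_shuffle_graph, collapsed into eager code: split
    # every input chunk, then concatenate piece j from every input into
    # output partition j.
    splits = [hash_partition(chunk, key, partition_count) for chunk in chunks]
    return [
        [row for split in splits for row in split.get(j, [])]
        for j in range(partition_count)
    ]


chunks = [
    [{"id": 1, "x": 10}, {"id": 2, "x": 20}],
    [{"id": 1, "x": 30}, {"id": 3, "x": 40}],
]
out = simple_shuffle(chunks, key="id", partition_count=2)
print(out)  # all rows with the same "id" end up in the same output partition

In the task-graph version each split and each concatenation is a separate task keyed by (name, partition index), so the scheduler decides where and when each piece runs.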